Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions test/float8/test_fsdp_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def setup(rank, world_size):
os.environ["MASTER_PORT"] = "12355"

# initialize the process group
dist.init_process_group("nccl", rank=rank, world_size=world_size)
device_type = torch.accelerator.current_accelerator().type
backend = dist.get_default_backend_for_device(device_type)
dist.init_process_group(backend, rank=rank, world_size=world_size)


def cleanup():
Expand Down Expand Up @@ -72,7 +74,9 @@ def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
# and modified
def fsdp_main(rank, world_size, args):
setup(rank, world_size)
torch.cuda.set_device(rank)
device_type = torch.accelerator.current_accelerator().type
torch.accelerator.set_device_index(rank)
device = f"{device_type}:{rank}"

(emulate,) = args

Expand All @@ -81,20 +85,20 @@ def fsdp_main(rank, world_size, args):
# regardless of float8.

model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=torch.bfloat16).to(
rank
device
)

# To compile FSDP, we need to set use_orig_params to True
model = FSDP(model, use_orig_params=True)

optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size)
input_local = torch.randn(B, M, K, N, device="cuda")
input_local = torch.randn(B, M, K, N, device=device)

model = torch.compile(model)

for _iter in range(N_ITER):
optimizer.zero_grad()
with torch.autocast("cuda"):
with torch.autocast(device_type):
y_local = model(input_local)
y_local.sum().backward()
optimizer.step()
Expand All @@ -105,17 +109,17 @@ def fsdp_main(rank, world_size, args):

def run():
emulate = False
if not torch.cuda.is_available():
warnings.warn("CUDA not available, running in emulation_mode", stacklevel=2)
if not torch.accelerator.is_available():
warnings.warn("GPU not available, running in emulation_mode", stacklevel=2)
emulate = True
elif torch.cuda.get_device_capability() < (9, 0):
elif torch.cuda.is_available() and torch.cuda.get_device_capability() < (9, 0):
warnings.warn(
f"CUDA capability {torch.cuda.get_device_capability()} < (9.0), running in emulation mode",
stacklevel=2,
)
emulate = True

WORLD_SIZE = torch.cuda.device_count()
WORLD_SIZE = torch.accelerator.device_count()
args = (emulate,)
mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True)

Expand Down
6 changes: 3 additions & 3 deletions test/float8/test_fsdp_compile.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

# terminate script on first error
set -e
if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False"; then
echo "Skipping test_fsdp_compile.sh because no CUDA devices are available."
if python -c 'import torch;print(torch.accelerator.is_available())' | grep -q "False"; then
echo "Skipping test_fsdp_compile.sh because no accelerator devices are available."
exit
fi

# Code to be executed if accelerator devices are available
NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp_compile.py
CCL_LOG_LEVEL=info NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/float8/test_fsdp_compile.py
Loading