#!/usr/bin/env python3
"""nccl-test: smoke-test that torch.distributed's NCCL backend works.

Run under torchrun (expects LOCAL_RANK / rendezvous env vars). Prints a
progress marker before and after each step so a hang is easy to localize.
"""

# nccl-test.py
print("nccl-test version 5.2")
import os
import sys
import socket
# NOTE: the original imported `time` twice; it is unused in this file, so
# both imports were dropped.

# torch.distributed cannot rendezvous across nodes when the local hostname
# resolves to the loopback address, so bail out early with a clear message
# instead of failing later inside init_process_group.
if socket.gethostbyname(socket.gethostname()) == "127.0.0.1":
    print("torch.distributed currently will not function properly if `hostname` resolves to 127.0.0.1")
    sys.exit(1)

print("start of import torch")
import torch
import torch.distributed as dist
print("end of import torch")

print("start of torch.cuda.is_available()")
# A visible CUDA device is a hard requirement for the NCCL backend.
if not torch.cuda.is_available():
    print("torch.cuda.is_available() is not working")
    # sys.exit instead of the builtin exit() — consistent with the hostname
    # check above, and exit() is only guaranteed to exist via the site module.
    sys.exit(1)
print("end of torch.cuda.is_available()")

# Bind this process to its GPU *before* creating the process group, so the
# NCCL communicator is associated with the correct device (otherwise every
# rank can default to cuda:0). LOCAL_RANK is set by torchrun.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

print("start of dist.init_process_group('nccl')")
dist.init_process_group("nccl")
print("end of dist.init_process_group('nccl')")

rank = dist.get_rank()
world_size = dist.get_world_size()

# Allocate a tensor of ones directly on this rank's GPU, then exercise a
# purely local op followed by the two NCCL collectives under test.
x = torch.ones(200, 100, device=f"cuda:{local_rank}")
print("starting element-wise add")
x = torch.add(x, x)
print("end of element-wise add")

print("starting all_reduce")
dist.all_reduce(x)  # default op is SUM across all ranks
print("end of all_reduce")

print("start of barrier")
dist.barrier()
print("end of barrier")

# Per-node success marker: local_rank 0 prints once on each node.
if local_rank == 0:
    # plain string — the original used an f-string with no placeholders
    print("NCCL WORKS!")
    print("SUCCESSFULLY EXITING SCRIPT")

# Tear down the process group so every rank exits cleanly (recent PyTorch
# warns if a process exits with the group still alive).
dist.destroy_process_group()