PyTorch DDP gets stuck while obtaining a free port

Problem description

I am trying to obtain a free port for PyTorch's DDP initialization, but my code hangs. The following snippet reproduces the problem:

import logging
import os
import socket
from contextlib import closing

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


def get_open_port():
    with closing(socket.socket(socket.AF_INET,socket.SOCK_STREAM)) as s:
        s.bind(('',0))
        s.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)
        return s.getsockname()[1]

def setup(rank,world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    port = get_open_port()
    os.environ['MASTER_PORT'] = str(port)   # e.g. '12345'

    # Initialize the process group.
    dist.init_process_group('nccl',rank=rank,world_size=world_size)


def cleanup():
    dist.destroy_process_group()

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel,self).__init__()
        self.net1 = nn.Linear(10,5)

    def forward(self,x):
        print(f'x device={x.device}')
        return self.net1(x)


def demo_basic(rank,world_size):
    setup(rank,world_size)

    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    logger.info(f'Running DDP on rank={rank}.')

    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model,device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(),lr=0.001)  # optimizer takes DDP model.

    optimizer.zero_grad()
    inputs = torch.randn(20,10)  # .to(rank)

    print(f'inputs device={inputs.device}')
    outputs = ddp_model(inputs)
    print(f'output device={outputs.device}')

    labels = torch.randn(20,5).to(rank)
    loss_fn(outputs,labels).backward()

    optimizer.step()

    cleanup()


def run_demo(demo_func,world_size):
    mp.spawn(
        demo_func,args=(world_size,),nprocs=world_size,join=True
    )

if __name__ == '__main__':
    run_demo(demo_basic,4)

The get_open_port function should release the port as soon as it returns. My questions are: 1. Why does the hang happen? 2. How can it be fixed?

Solution

The answer comes from here. In detail: 1. Because each spawned process calls get_open_port on its own, every rank ends up with a different MASTER_PORT, so the processes never rendezvous; 2. The fix is to obtain one free port up front and pass it to all processes.
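
To see point 1 in isolation, here is a minimal sketch (standard library only, not part of the original answer) showing that independent processes calling get_open_port generally receive different ports, which is why each rank would wait on its own MASTER_PORT:

import socket
from contextlib import closing
from multiprocessing import Process

def get_open_port():
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        s.bind(('', 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]

def report(rank):
    # Each process asks the OS for an ephemeral port independently,
    # so the printed values usually differ between ranks.
    print(f'rank {rank} would use MASTER_PORT={get_open_port()}')

if __name__ == '__main__':
    procs = [Process(target=report, args=(r,)) for r in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()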

The corrected snippet:

def get_open_port():
    with closing(socket.socket(socket.AF_INET,socket.SOCK_STREAM)) as s:
        s.bind(('',0))
        s.setsockopt(socket.SOL_SOCKET,socket.SO_REUSEADDR,1)
        return s.getsockname()[1]


def setup(rank,world_size,port):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = str(port)

    # Initialize the process group.
    dist.init_process_group('nccl',rank=rank,world_size=world_size)


def cleanup():
    dist.destroy_process_group()


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel,self).__init__()
        self.net1 = nn.Linear(10,5)

    def forward(self,x):
        print(f'x device={x.device}')
        # return self.net2(self.relu(self.net1(x)))
        return self.net1(x)


def demo_basic(rank,world_size,free_port):
    setup(rank,world_size,free_port)

    logger = logging.getLogger('train')
    logger.setLevel(logging.DEBUG)
    logger.info(f'Running DDP on rank={rank}.')

    # Create model and move it to GPU.
    model = ToyModel().to(rank)
    ddp_model = DDP(model,device_ids=[rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(),lr=0.001)  # optimizer takes DDP model.

    optimizer.zero_grad()
    inputs = torch.randn(20,10)  # .to(rank)

    print(f'inputs device={inputs.device}')
    outputs = ddp_model(inputs)
    print(f'output device={outputs.device}')

    labels = torch.randn(20,5).to(rank)
    loss_fn(outputs,labels).backward()

    optimizer.step()

    cleanup()


def run_demo(demo_func,world_size,free_port):
    mp.spawn(
        demo_func,args=(world_size,free_port),nprocs=world_size,join=True
    )

if __name__ == '__main__':
    free_port = get_open_port()
    run_demo(demo_basic,4,free_port)
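
As a side note, the same single-node rendezvous can also be expressed without environment variables by passing the address and port through init_method. This is only a sketch of an alternative setup, not part of the original answer:

def setup(rank, world_size, port):
    # Same rendezvous as above, but the address and port are passed
    # through the init_method URL instead of MASTER_ADDR/MASTER_PORT.
    dist.init_process_group(
        'nccl',
        init_method=f'tcp://127.0.0.1:{port}',
        rank=rank,
        world_size=world_size,
    )

Either way, there is a small window between get_open_port releasing the port and rank 0 binding it again, so another program could in principle grab it in the meantime; on a single machine this is usually acceptable, but it is worth keeping in mind.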