PySyft workers overfitting

Problem description

I am trying to train an image classifier (CIFAR-10) with PySyft. My training setup has 10 workers, and each worker gets between 800 and 1200 images from the dataset.
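The question does not show how the images were distributed to the workers. Purely as an illustration, a random IID split in which each of 10 workers receives between 800 and 1200 of the 50,000 CIFAR-10 training images could be sketched as follows. `split_among_workers` is a hypothetical helper, not part of PySyft, and the real split in the question may differ (for example, it may not be IID).

```python
import random


def split_among_workers(num_items, num_workers, min_size, max_size, seed=0):
    """Shuffle item indices and cut them into one disjoint shard per worker.

    Each shard size is drawn uniformly from [min_size, max_size] (inclusive).
    Items beyond the sum of the shard sizes are simply left unused.
    """
    rng = random.Random(seed)
    sizes = [rng.randint(min_size, max_size) for _ in range(num_workers)]
    if sum(sizes) > num_items:
        raise ValueError("not enough items for the requested shard sizes")
    indices = list(range(num_items))
    rng.shuffle(indices)
    shards, start = [], 0
    for size in sizes:
        shards.append(indices[start:start + size])
        start += size
    return shards


# 50,000 = size of the CIFAR-10 training set; 10 hypothetical workers
shards = split_among_workers(50_000, 10, 800, 1200)
print([len(s) for s in shards])
```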

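My problem is that after about 250-300 epochs the model stops improving: the training loss is around 0.005, but test accuracy stays at about 45% while the test loss rises from 1.5 to 8.5. I also tried 100 workers with 100 images each, and accuracy plateaued at 32%. In addition, this implementation is part of a comparison of FL frameworks, so the model cannot be changed, and the data is loaded locally and converted into a DataLoader. I have no experience with PyTorch or PySyft, so although I tried to stay as close to the examples as possible, I may have made some mistakes in how I train the model.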

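I trained the same model without PySyft and it reached about 85% accuracy, so I don't think my data loader or model is the problem. To me it looks as if, during training, the workers overfit on their own data.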
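A training loss heading toward zero while the held-out loss keeps rising is the usual signature of overfitting. One common, framework-independent guard is early stopping: keep the model from the evaluation with the lowest test (or validation) loss and stop once it has not improved for a few evaluations. The sketch below works on a plain list of per-evaluation losses (the script in the question evaluates every 20 epochs); `early_stop`, `patience` and the loss values in the usage line are illustrative assumptions, not taken from the question's logs.

```python
def early_stop(test_losses, patience):
    """Return (stop_index, best_index) for a sequence of evaluation losses.

    stop_index is the first evaluation at which the loss has not improved on
    its best value for `patience` consecutive evaluations (None if that never
    happens); best_index is the evaluation whose model should be kept.
    """
    best_loss, best_idx = float("inf"), -1
    for i, loss in enumerate(test_losses):
        if loss < best_loss:
            best_loss, best_idx = loss, i
        elif i - best_idx >= patience:
            return i, best_idx
    return None, best_idx


# Made-up losses that first fall, then rise (overfitting)
print(early_stop([2.0, 1.6, 1.5, 1.9, 3.0, 5.2], patience=2))
```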

Is there a way to prevent the individual workers from overfitting, or to compute the loss for the global model instead?
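The training loop in the question never averages anything: for each batch it sends the one shared model to that batch's worker, takes a single SGD step there, and pulls the model back. A commonly used alternative is Federated Averaging (FedAvg): every worker trains its own copy of the model for some local steps, and the server replaces the global model with the average of those copies, weighted by each worker's number of samples; the "global" test loss is then simply the loss of that averaged model on the test set. The sketch below shows only the aggregation step in plain Python, with lists of floats standing in for tensors. `fed_avg` and the parameter names are hypothetical and not PySyft API; with real PyTorch models the same average would be taken over the tensors in each copy's `state_dict()`.

```python
def fed_avg(worker_params, worker_sizes):
    """Sample-weighted average of per-worker parameters (FedAvg aggregation).

    worker_params: list of dicts mapping parameter name -> list of floats
                   (stand-ins for tensors); all dicts share the same layout.
    worker_sizes:  number of training samples held by each worker.
    """
    total = sum(worker_sizes)
    averaged = {}
    for name in worker_params[0]:
        length = len(worker_params[0][name])
        averaged[name] = [
            sum(p[name][j] * n for p, n in zip(worker_params, worker_sizes)) / total
            for j in range(length)
        ]
    return averaged


# Two hypothetical workers holding 800 and 1200 samples
w1 = {"fc.weight": [1.0, 2.0], "fc.bias": [0.0]}
w2 = {"fc.weight": [3.0, 4.0], "fc.bias": [1.0]}
avg = fed_avg([w1, w2], [800, 1200])
print(avg)
```

Weighting by sample count means a worker with 1200 images pulls the average further toward its own parameters than a worker with 800.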

Trainer:

    
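import time

import torch.nn as nn
import torch.nn.functional as F


def fl_train(args, model, device, federated_train_loader, optimizer, epoch, log):
    model.train()
    results = []
    t1 = time.time()
    for batch_idx, (data, target) in enumerate(federated_train_loader):  # <-- Now it is a distributed dataset
        model.send(data.location)  # <-- NEW: send the model to the right location
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)  # log-probabilities: the model ends in log_softmax
        loss = F.nll_loss(output, target.long())  # so NLL (not CrossEntropyLoss) is the matching loss
        loss.backward()
        optimizer.step()
        model.get()  # <-- NEW: get the model back
        if batch_idx % args.log_interval == 0:
            loss = loss.get()  # <-- NEW: get the loss back
            results.append(loss.item())
            # BATCH_SIZE is the global batch size used to build the federated loader
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * BATCH_SIZE, len(federated_train_loader) * BATCH_SIZE,
                100. * batch_idx / len(federated_train_loader), loss.item()))
    # Hypothetical log layout: keep the sampled losses and epoch time, and
    # return the dict because the main loop reassigns `log` from this call
    log.setdefault('train_loss', []).extend(results)
    log.setdefault('epoch_time', []).append(time.time() - t1)
    return log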
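The progress line multiplies a batch index by `BATCH_SIZE`. The log is consistent with `BATCH_SIZE = 100`, `args.log_interval = 30` and 104 batches per epoch; these values are inferred from the log, not stated in the question. A minimal sketch reproducing the same format under those assumptions (`progress_line` is a hypothetical helper):

```python
def progress_line(epoch, batch_idx, num_batches, batch_size, loss):
    # Same format string and arguments as the print() call in fl_train
    return 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * batch_size, num_batches * batch_size,
        100. * batch_idx / num_batches, loss)


# Assumed: 104 batches of 100, logging every 30 batches
lines = [progress_line(317, i, 104, 100, 0.005194) for i in range(0, 104, 30)]
for line in lines:
    print(line)
```

Because the denominator is a batch count times `BATCH_SIZE`, it is only an upper bound on the number of training images whenever some batches are smaller than `BATCH_SIZE`.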

Model:

class CNN(nn.Module):

    def __init__(self):
        super(CNN,self).__init__()

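        self.conv_layer = nn.Sequential(
            # Conv Layer block 1: 3x32x32 -> 32x15x15
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 2)),
            # Conv Layer block 2: 32x15x15 -> 64x6x6
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.MaxPool2d((2, 2)),
            # Conv Layer block 3: 64x6x6 -> 64x4x4 (= 1024 features, no pooling)
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
            nn.ReLU(inplace=True),
        )

        self.fc_layer = nn.Sequential(
            nn.Linear(1024, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 10),
        )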


    def forward(self,x):
        # CNN layers
        x = self.conv_layer(x)

        # flatten
        x = x.view(-1,1024)

        # NN layer
        x = self.fc_layer(x)
        return F.log_softmax(x,dim=1)
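Whatever the exact layers are, the flattened feature size must equal 1024 for `x.view(-1,1024)` and `nn.Linear(1024,64)` to line up. Assuming three unpadded 3x3 convolutions with 32, 64 and 64 output channels, and a 2x2 max-pool (stride 2) after only the first two, the standard output-size formula out = floor((in + 2*padding - kernel) / stride) + 1 gives exactly that. `conv_out` is a hypothetical helper:

```python
def conv_out(size, kernel, stride=1, padding=0):
    # Output side length of a square conv or pooling layer (no dilation)
    return (size + 2 * padding - kernel) // stride + 1


size = 32                    # CIFAR-10 images are 32x32
size = conv_out(size, 3)     # block 1 conv -> 30
size = conv_out(size, 2, 2)  # block 1 pool -> 15
size = conv_out(size, 3)     # block 2 conv -> 13
size = conv_out(size, 2, 2)  # block 2 pool -> 6
size = conv_out(size, 3)     # block 3 conv -> 4
channels = 64                # output channels of block 3
print(size, channels * size * size)
```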

Main:

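import torch
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.02)  # TODO: momentum is not supported at the moment
log = {}
for epoch in range(1, args.epochs + 1):
    log = fl_train(args, model, device, f_DataLoader, optimizer, epoch, log)
    if epoch % 20 == 0:
        log = test(args, model, device, test_loader, log)  # test() is not shown; this signature is assumed
    if epoch % 100 == 0:
        store_results(log, model)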

Log:

....
Train Epoch: 317 [0/10400 (0%)] Loss: 0.005194
Train Epoch: 317 [3000/10400 (29%)] Loss: 0.003882
Train Epoch: 317 [6000/10400 (58%)] Loss: 0.003100
Train Epoch: 317 [9000/10400 (87%)] Loss: 0.004298
Train Epoch: 318 [0/10400 (0%)] Loss: 0.007426
Train Epoch: 318 [3000/10400 (29%)] Loss: 0.002255
Train Epoch: 318 [6000/10400 (58%)] Loss: 0.003835
Train Epoch: 318 [9000/10400 (87%)] Loss: 0.005277
Train Epoch: 319 [0/10400 (0%)] Loss: 0.006207
Train Epoch: 319 [3000/10400 (29%)] Loss: 0.003562
Train Epoch: 319 [6000/10400 (58%)] Loss: 0.001904
Train Epoch: 319 [9000/10400 (87%)] Loss: 0.002644
Train Epoch: 320 [0/10400 (0%)] Loss: 0.007491
Train Epoch: 320 [3000/10400 (29%)] Loss: 0.003794
Train Epoch: 320 [6000/10400 (58%)] Loss: 0.002643
Train Epoch: 320 [9000/10400 (87%)] Loss: 0.002981
Test set: Average loss: 9.1279,Accuracy: 458/1000 (46%)

Train Epoch: 321 [0/10400 (0%)] Loss: 0.007153
Train Epoch: 321 [3000/10400 (29%)] Loss: 0.004265
Train Epoch: 321 [6000/10400 (58%)] Loss: 0.002708
Train Epoch: 321 [9000/10400 (87%)] Loss: 0.002518
Train Epoch: 322 [0/10400 (0%)] Loss: 0.006285
Train Epoch: 322 [3000/10400 (29%)] Loss: 0.002357
Train Epoch: 322 [6000/10400 (58%)] Loss: 0.002465
Train Epoch: 322 [9000/10400 (87%)] Loss: 0.002406
Train Epoch: 323 [0/10400 (0%)] Loss: 0.005361
Train Epoch: 323 [3000/10400 (29%)] Loss: 0.004807
Train Epoch: 323 [6000/10400 (58%)] Loss: 0.001903
Train Epoch: 323 [9000/10400 (87%)] Loss: 0.003711
Train Epoch: 324 [0/10400 (0%)] Loss: 0.006609
....
