Huge difference between training DeepLabV3 (ResNet) on Google Colab and on a local machine

Problem description

I am trying to train DeepLabV3 (ResNet-101) to perform semantic segmentation on a custom dataset. I had been working on my local machine, but my GPU is only a small Quadro T1000, so I decided to move my model to Google Colab to take advantage of their GPU instances and get better results.

While I got the speed boost I was hoping for, my training loss on Colab is wildly different from the one on my local machine. I copied and pasted exactly the same code, so the only difference I can find is in the dataset. I am using exactly the same dataset, except that the one on Colab is a copy of the local dataset uploaded to Google Drive. I did notice that Drive orders the files differently than Windows does, but I can't see how that would be a problem, since I shuffle the dataset randomly anyway. I understand that these random splits can cause small differences in the output, but a roughly 10x difference in training loss makes no sense.
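(One thing worth noting when filesystems enumerate files in different orders: shuffling does not protect the input/target pairing; if the two filename lists come back in different orders, the pairing is silently destroyed. A way to rule this out is to sort both lists and shuffle them jointly. A minimal sketch; get_sorted_filenames is a hypothetical variant of the get_filenames_of_path helper in the code below, and the paths and constants mirror the question:

import pathlib

from sklearn.model_selection import train_test_split

def get_sorted_filenames(path: pathlib.Path, ext: str = '*'):
    """Like get_filenames_of_path below, but sorted by name so the
    order is identical on every filesystem (Drive, ext4, NTFS)."""
    return sorted(file for file in path.glob(ext) if file.is_file())

root = pathlib.Path.cwd() / 'train'
inputs = get_sorted_filenames(root / 'input')
targets = get_sorted_filenames(root / 'target')

# Passing both lists to a single call applies the same shuffle to both,
# so inputs_train[i] always corresponds to targets_train[i].
inputs_train, inputs_valid, targets_train, targets_valid = train_test_split(
    inputs, targets, random_state=142, train_size=0.8, shuffle=True)
)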

I have also tried running the Colab version with different random seeds, different batch sizes, and different train_test_split parameters, and I switched the optimizer from SGD to Adam; however, the model still converges very early, with a loss of around 0.5.
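(As an aside, the seed passed to train_test_split only affects the split itself; weight initialisation, DataLoader shuffling, and the augmentations each draw from their own RNGs. To make two environments take the same random decisions end to end, every one of those RNGs has to be seeded. A hedged sketch of the usual recipe, using standard PyTorch/NumPy calls that are not part of the original code:

import random

import numpy as np
import torch

def seed_everything(seed: int = 142) -> None:
    # Seed every RNG that influences initialisation, shuffling,
    # and augmentation.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade speed for determinism in cuDNN convolutions.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def seed_worker(worker_id: int) -> None:
    # Pass as worker_init_fn=seed_worker to the DataLoader so each of
    # the num_workers subprocesses gets a derived, reproducible seed.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
)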

Here is my code:

import torch
from torch.utils import data
from torchvision import transforms
from customdatasets import SegmentationDataSet
import pathlib
from sklearn.model_selection import train_test_split
from customtransforms import Compose,AlbuSeg2d,DenseTarget
from customtransforms import MoveAxis,normalize01,Resize
import albumentations
import matplotlib.pyplot as plt
import time
import GPUtil



def get_filenames_of_path(path: pathlib.Path,ext: str = '*'):
    """Returns a list of files in a directory/path. Uses pathlib."""
    filenames = [file for file in path.glob(ext) if file.is_file()]
    return filenames


if __name__ == '__main__':

    root = pathlib.Path.cwd() / 'train'
    inputs = get_filenames_of_path(root / 'input')
    targets = get_filenames_of_path(root / 'target')

    

    # training transformations and augmentations
    transforms_training = Compose([
        Resize(input_size=(128, 128, 3), target_size=(128, 128)),
        AlbuSeg2d(albu=albumentations.HorizontalFlip(p=0.5)),
        MoveAxis(),
        normalize01()
    ])

    # validation transformations (same resizing, no augmentation)
    transforms_validation = Compose([
        Resize(input_size=(128, 128, 3), target_size=(128, 128)),
        MoveAxis(),
        normalize01()
    ])
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    random_seed = 142
    train_size = 0.8

    inputs_train, inputs_valid = train_test_split(
        inputs, random_state=random_seed, train_size=train_size, shuffle=True)
    # Same seed and ratio for the targets so input/target pairs stay aligned.
    targets_train, targets_valid = train_test_split(
        targets, random_state=random_seed, train_size=train_size, shuffle=True)

    dataset_train = SegmentationDataSet(inputs=inputs_train,targets=targets_train,transform=transforms_training,device=device)

    dataset_valid = SegmentationDataSet(inputs=inputs_valid,targets=targets_valid,transform=transforms_validation,device=device)


    DataLoader_training = data.DataLoader(dataset=dataset_train,batch_size=15,shuffle=True,num_workers=4,pin_memory=True)

    DataLoader_validation = data.DataLoader(dataset=dataset_valid,pin_memory=True)


    model = torch.hub.load('pytorch/vision:v0.6.0','deeplabv3_resnet101',pretrained=False)


    criterion = torch.nn.CrossEntropyLoss()

    model = model.to(device)

    optimizer = torch.optim.SGD(model.parameters(),lr=0.001,momentum=0.99)


    epochs = 10
    steps = 0
    running_loss = 0
    print_every = 10
    train_losses,valid_losses = [],[]

    start_time = time.time()
    prev_time = time.time()


    for epoch in range(epochs):
        #Training
        for inputs,labels in DataLoader_training:
            steps += 1
            inputs,labels = inputs.to(device,non_blocking=True),labels.to(device,non_blocking=True)
            optimizer.zero_grad()
            logps = model(inputs)
            loss = criterion(logps['out'],labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if steps % print_every == 0:
                # Average loss over the last print_every batches, matching
                # the value printed below.
                train_losses.append(running_loss / print_every)
                epoch_time = time.time()
                elapsed_time = epoch_time - prev_time
                prev_time = epoch_time
                print(f"Epoch {epoch + 1}/{epochs}.. "
                      f"Train loss: {running_loss / print_every:.3f}.. "
                      f"Elapsed time: {elapsed_time}")

                running_loss = 0
        # Evaluation
        valid_loss = 0
        accuracy = 0
        model.eval()
        with torch.no_grad():
            for inputs, labels in DataLoader_validation:
                # Move the batch to the GPU, same as in the training loop.
                inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
                logps = model(inputs)
                batch_loss = criterion(logps['out'], labels)
                valid_loss += batch_loss.item()

                # The model returns raw logits rather than log-probabilities,
                # but exp() is monotonic, so topk still picks the argmax class.
                ps = torch.exp(logps['out'])
                top_p, top_class = ps.topk(1, dim=1)
                equals = top_class == labels.view(*top_class.shape)
                accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
        valid_losses.append(valid_loss / len(DataLoader_validation))
        print(f"Epoch {epoch + 1}/{epochs}.. "
            f"Validation loss: {valid_loss / len(DataLoader_training):.3f}.. "
            f"Validation accuracy: {accuracy / len(DataLoader_training):.3f} ")
        model.train()
    torch.save(model,'model.pth')

    end_time = time.time()
    total_time = end_time - start_time
    print("Total Time: ",total_time)
    plt.plot(train_losses,label='Training loss')
    plt.plot(valid_losses,label='Validation loss')
    plt.legend(frameon=False)
    plt.show()

Here is the output of one epoch on Colab:

Epoch 1/10.. Train loss: 2.080.. Elapsed time: 12.156640768051147
Epoch 1/10.. Train loss: 1.231.. Elapsed time: 8.76858925819397
Epoch 1/10.. Train loss: 1.051.. Elapsed time: 8.315532445907593
Epoch 1/10.. Train loss: 0.890.. Elapsed time: 8.249168634414673
Epoch 1/10.. Train loss: 0.839.. Elapsed time: 8.248667478561401
Epoch 1/10.. Train loss: 0.807.. Elapsed time: 8.120820999145508
Epoch 1/10.. Train loss: 0.742.. Elapsed time: 8.298616886138916
Epoch 1/10.. Train loss: 0.726.. Elapsed time: 8.170734167098999
Epoch 1/10.. Train loss: 0.677.. Elapsed time: 8.221246004104614
Epoch 1/10.. Train loss: 0.698.. Elapsed time: 8.124614000320435
Epoch 1/10.. Train loss: 0.675.. Elapsed time: 8.197462558746338
Epoch 1/10.. Train loss: 0.682.. Elapsed time: 8.263437509536743
Epoch 1/10.. Train loss: 0.626.. Elapsed time: 8.156179189682007
Epoch 1/10.. Train loss: 0.632.. Elapsed time: 8.268096446990967
Epoch 1/10.. Train loss: 0.616.. Elapsed time: 8.214547872543335
Epoch 1/10.. Train loss: 0.585.. Elapsed time: 8.31475019454956
Epoch 1/10.. Train loss: 0.598.. Elapsed time: 8.388074398040771
Epoch 1/10.. Train loss: 0.626.. Elapsed time: 8.179292440414429
Epoch 1/10.. Train loss: 0.612.. Elapsed time: 8.252359390258789
Epoch 1/10.. Train loss: 0.592.. Elapsed time: 8.284745693206787
Epoch 1/10.. Train loss: 0.597.. Elapsed time: 8.31213927268982
Epoch 1/10.. Train loss: 0.566.. Elapsed time: 8.164374113082886
Epoch 1/10.. Train loss: 0.556.. Elapsed time: 8.300082206726074
Epoch 1/10.. Train loss: 0.568.. Elapsed time: 8.26304841041565
Epoch 1/10.. Train loss: 0.572.. Elapsed time: 8.309881448745728
Epoch 1/10.. Train loss: 0.586.. Elapsed time: 8.211671352386475
Epoch 1/10.. Train loss: 0.586.. Elapsed time: 8.321797609329224
Epoch 1/10.. Train loss: 0.535.. Elapsed time: 8.318871021270752
Epoch 1/10.. Train loss: 0.543.. Elapsed time: 8.152915239334106
Epoch 1/10.. Train loss: 0.569.. Elapsed time: 8.251380205154419
Epoch 1/10.. Train loss: 0.526.. Elapsed time: 8.29153847694397
Epoch 1/10.. Train loss: 0.565.. Elapsed time: 8.15071702003479
Epoch 1/10.. Train loss: 0.542.. Elapsed time: 8.253364562988281
Epoch 1/10.. Validation loss: 0.182.. Validation accuracy: 0.271 

And here is the output on my local machine:

Epoch 1/10.. Train loss: 2.932.. Elapsed time: 32.148621797561646
Epoch 1/10.. Train loss: 1.852.. Elapsed time: 14.120505809783936
Epoch 1/10.. Train loss: 0.887.. Elapsed time: 14.210048198699951
Epoch 1/10.. Train loss: 0.618.. Elapsed time: 14.23294186592102
Epoch 1/10.. Train loss: 0.549.. Elapsed time: 14.212541103363037
Epoch 1/10.. Train loss: 0.519.. Elapsed time: 14.047481775283813
Epoch 1/10.. Train loss: 0.506.. Elapsed time: 14.060708045959473
Epoch 1/10.. Train loss: 0.347.. Elapsed time: 14.301624059677124
Epoch 1/10.. Train loss: 0.399.. Elapsed time: 13.9844491481781
Epoch 1/10.. Train loss: 0.361.. Elapsed time: 13.957871913909912
Epoch 1/10.. Train loss: 0.305.. Elapsed time: 14.164010763168335
Epoch 1/10.. Train loss: 0.296.. Elapsed time: 14.001536846160889
Epoch 1/10.. Train loss: 0.298.. Elapsed time: 14.019971132278442
Epoch 1/10.. Train loss: 0.271.. Elapsed time: 13.951345443725586
Epoch 1/10.. Train loss: 0.252.. Elapsed time: 14.037938594818115
Epoch 1/10.. Train loss: 0.283.. Elapsed time: 13.944657564163208
Epoch 1/10.. Train loss: 0.299.. Elapsed time: 13.977224826812744
Epoch 1/10.. Train loss: 0.219.. Elapsed time: 13.941975355148315
Epoch 1/10.. Train loss: 0.242.. Elapsed time: 13.936140060424805
Epoch 1/10.. Train loss: 0.244.. Elapsed time: 13.942122459411621
Epoch 1/10.. Train loss: 0.216.. Elapsed time: 13.960899114608765
Epoch 1/10.. Train loss: 0.186.. Elapsed time: 13.956881523132324
Epoch 1/10.. Train loss: 0.241.. Elapsed time: 13.944581985473633
Epoch 1/10.. Train loss: 0.203.. Elapsed time: 13.934357404708862
Epoch 1/10.. Train loss: 0.189.. Elapsed time: 13.938358306884766
Epoch 1/10.. Train loss: 0.181.. Elapsed time: 13.944468021392822
Epoch 1/10.. Train loss: 0.186.. Elapsed time: 13.946297407150269
Epoch 1/10.. Train loss: 0.164.. Elapsed time: 13.940366744995117
Epoch 1/10.. Train loss: 0.165.. Elapsed time: 13.938241720199585
Epoch 1/10.. Train loss: 0.176.. Elapsed time: 14.015569925308228
Epoch 1/10.. Train loss: 0.165.. Elapsed time: 14.019208669662476
Epoch 1/10.. Train loss: 0.175.. Elapsed time: 14.149503469467163
Epoch 1/10.. Train loss: 0.159.. Elapsed time: 14.128302097320557
Epoch 1/10.. Train loss: 0.155.. Elapsed time: 13.935027837753296
Epoch 1/10.. Train loss: 0.137.. Elapsed time: 13.937382221221924
Epoch 1/10.. Train loss: 0.127.. Elapsed time: 13.929635524749756
Epoch 1/10.. Train loss: 0.133.. Elapsed time: 13.935472011566162
Epoch 1/10.. Train loss: 0.152.. Elapsed time: 13.922808647155762
Epoch 1/10.. Validation loss: 0.032.. Validation accuracy: 0.239

I won't paste any more of it because it is long and takes a while to run, but by the end of epoch 3 the Colab model's loss was still bouncing around 0.5, while the local one had reached 0.02.

I would greatly appreciate any help figuring this out.

Workaround

I solved this by unzipping the training data into Google Drive and reading the files from there, rather than using a Colab command to unzip the folder directly into my workspace. I have absolutely no idea why this caused the problem; a quick visual inspection of the images and their corresponding tensors looked fine, but there was no way to go through all 6,000 or so images and check every single one. If anyone knows why this caused the problem, please let me know!
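(For anyone hitting something similar: instead of eyeballing thousands of images, you can check whether two copies of the dataset are byte-identical by hashing every file on each machine and diffing the resulting manifests. A minimal sketch, assuming the train/input and train/target layout from the question; hash_file and write_manifest are hypothetical helpers, not part of the original code:

import hashlib
import pathlib

def hash_file(path: pathlib.Path) -> str:
    """MD5 of a file's raw bytes, read in chunks to keep memory flat."""
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            md5.update(chunk)
    return md5.hexdigest()

def write_manifest(root: pathlib.Path, out_file: str = 'manifest.txt') -> None:
    # One "relative_path<TAB>hash" line per file, sorted so the
    # manifests are directly diffable across machines.
    files = sorted(p for p in root.rglob('*') if p.is_file())
    with open(out_file, 'w') as out:
        for p in files:
            out.write(f"{p.relative_to(root)}\t{hash_file(p)}\n")

write_manifest(pathlib.Path.cwd() / 'train')

Run it on both machines and diff the two manifest.txt files; any path whose hash differs, or that appears on only one side, was corrupted or dropped during the copy/unzip step.)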
