如何在 pythorch 中创建增强数据集

问题描述

我必须添加到原始 CIFAR 数据集,对于每个图像,对应的图像,旋转 90 度。这个想法是创建一个 RotationDateset,一个扩展 datasets.VisionDataset 的类,它采用 CIFAR 并执行上述操作。

<!DOCTYPE html>
<html lang="en">
<head>
  <Meta charset="UTF-8">
  <Meta http-equiv="X-UA-Compatible" content="IE=edge">
  <Meta name="viewport" content="width=device-width,initial-scale=1.0">
  <script src="script.js"></script>
  <title>Document</title>
</head>
<body>
  <button type="button" onclick="changeVar('w')">wheels Button</button>
  <button type="button" onclick="changeVar('c')">BMW Button</button>
</body>
</html>

//org_dataset 是 CIFAR //num_rots 为 4 //转换是transforms.Compose([transforms.ToTensor(),transforms.normalize((0.5,0.5,0.5),(0.5,0.5))])

from __future__ import print_function,division
import skimage.io

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision.datasets import ImageFolder
import torchvision
from torchvision import datasets,models,transforms
import matplotlib.pyplot as plt
import time
import os
from sklearn.model_selection import train_test_split
import copy
import cv2
from torchvision.models.resnet import BasicBlock
from torchvision.models.resnet import resnet
from PIL import Image
import xml.etree.ElementTree as ET
from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torchvision.models.resnet import model_urls

这是我最初导入和转换 CIFAR 的方法

class RotDataset(datasets.VisionDataset):
    def __init__(self,org_dataset,transforms,num_rots):
        
        self.samples = org_dataset.data
        self.targets = []
        self.num_rots = num_rots
        self.transforms = transforms

        for k in self.samples:
          self.targets.append(k)

          for i in range(0,self.num_rots):
            tr = torchvision.transforms.Compose([torchvision.transforms.Randomrotation(degrees=90*i),torchvision.transforms.ToTensor(),torchvision.transforms.normalize((0.5,0.5))])
            # from PIL import Image
            p_i = Image.fromarray(k)
            te = tr(p_i)
            r_im = torch.reshape(te,(k.shape))
            r_im = np.array(r_im)
            self.targets.append(r_im)
            
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self,index):
      imgs = self.targets[index:index + self.num_rots]
      labels = list(range(0,self.num_rots))

      return imgs,labels

这是我创建 CIFAR 增强的方法

transform = transforms.Compose([transforms.ToTensor(),0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform)


testset = torchvision.datasets.CIFAR10(root='./data',train=False,transform=transform)

classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')

和模型训练

cifar_rot = RotDataset(trainset,trainset.transforms,4)

rot_train,rot_val= train_test_split(
np.arange(len(cifar_rot.targets)),test_size=0.2,shuffle=True,)

train_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_train)
val_sampler_rot = torch.utils.data.SubsetRandomSampler(rot_val)

DataLoaders_rot = {'train': torch.utils.data.DataLoader(cifar_rot,batch_size=8,sampler=train_sampler_rot),'val':torch.utils.data.DataLoader(cifar_rot,sampler=val_sampler_rot)}

sizes_rot = {'train':len(rot_train)*4,'val':len(rot_val)*4}

//问题是我启动模型的时候pythorch抛出这个错误

model_rot = torchvision.models.resnet34(pretrained=False) 

num_ftrs = model_rot.fc.in_features
output_dim_rot = 4 # since are 4 rotations

model_rot.fc = nn.Linear(num_ftrs,output_dim_rot)

model_rot = model_rot.to(device)
criterion = nn.CrossEntropyLoss()

optimizer_conv = optim.SGD(model_rot.parameters(),lr=0.001,momentum=0.9)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv,step_size=7,gamma=0.1)
model_rot = train_model(model_rot,criterion,optimizer_conv,exp_lr_scheduler,DataLoaders_rot,sizes_rot,num_epochs=10)

torch.save(model_rot.state_dict(),'rotation_resnet34_10_epochs.pt')

谁能帮帮我?提前致谢

解决方法

问题来自于您对 org_dataset.data 的依赖,这是一个形状为 (N,32,3) 的 numpy 数组(您希望它是 (N,3,32)

因此,使用 self.targets.append(k) 行,您在目标列表中放置了不正确的形状。然后,张量 te 具有正确的形状(感谢 ToTensor),但是您在

我还想指出,诸如 RandomRotation 之类的随机变换通常应用于 __getitem__ 方法中,而不是 __init__ 中。由于在这些变换中生成随机数,因此您希望每个时期都生成新样本,以便拥有几乎无限的数据集和样本。我实际上不确定您是否理解 RandomRotation 的作用:它使用 随机旋转 旋转输入张量,您只需指定可能的角度范围。因此,完全有可能应用参数 180 (i=2) 的“旋转”将产生几乎不变的张量。我看到您之后试图预测 i 的值,它很可能不起作用。您可能想改用 torch.rot90

除此之外,由于您已经在 ToTensor 中应用了 NormalizeRotDataset,因此您当然不需要在 CIFAR10 中使用它们。

最后评论:我真的不明白你为什么要 __getitem 返回张量(和标签)列表。我会在下面的代码中保持这种方式,但看起来它最终会破坏某些东西。

因此,您可以通过以下方式更正代码:

class RotDataset(datasets.VisionDataset):
    def __init__(self,org_dataset,transforms,num_rots):
    
        # Let's buffer the underlying dataset,we will sample   
        # from it on the fly
        self.dataset = org_dataset
        self.num_rots = num_rots
        # You did not use this attribute previously,probably a mistake
        # It will now be applied in the __getitem__
        self.transforms = transforms
        
    def __len__(self):
        # Typical front dataset : size is the same as the 
        # underlying dataset size
        return len(self.dataset)

    def __getitem__(self,index):
        # sampling from CIFAR10
        sample = self.dataset[index]
        # Because you want to return a list
        imgs = []
        for i in range(0,self.num_rots):
            # Creating the corresponding rotation
            rotation = torchvision.transforms.RandomRotation(degrees=90*i)
            # Applying rotation,followed by other transforms (toTensor,Normalize...)
            transformed = self.transform(rotation(sample))
            imgs.append(transformed)

        # Cleaner way to generate your range : 
        labels = np.arange(self.num_rots)

        return imgs,labels

# transform=None,since we will apply them in RotDataset
trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=None)
# The transforms to call in RotDataset
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5))])
cifar_rot = RotDataset(trainset,transform,4)

# using torch's random split to remove dependency on sklearn
from torch.utils.data import random_split
test_size = 0.2*len(cifar_rot)
rot_train,rot_val= random_split(cifar_rot,[len(cifar_rot)-test_size,test_size])

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...