分类器对象没有属性 train

问题描述

我在 datatset.py 模块中收到此错误,该模块用于为 siamese 网络创建正负对。知道为什么给我这个错误吗? 我将这个 Siamese MNIST 称为我在训练数据集上的主要模块“itr1.py”,用于我自己的自定义数据集。

Error log

siamese_train_dataset = SiameseMNIST(train_dataset)
siamese_test_dataset = SiameseMNIST(test_dataset)


class SiameseMNIST(Dataset):
    def __init__(self,dataset,train=True):
        self.dataset = dataset
        self.train = self.dataset.train
        self.transform = self.dataset.transform

        if self.train:
            self.train_labels = self.dataset.train_labels
            self.train_data = self.dataset.train_data
            self.labels_set = set(self.train_labels.numpy())
            self.label_to_indices = {label: np.where(self.train_labels.numpy() == label)[0]
                                     for label in self.labels_set}
        else:
            # generate fixed pairs for testing
            self.test_labels = self.dataset.test_labels
            self.test_data = self.dataset.test_data
            self.labels_set = set(self.test_labels.numpy())
            self.label_to_indices = {label: np.where(self.test_labels.numpy() == label)[0]
                                     for label in self.labels_set}

            random_state = np.random.RandomState(29)

            positive_pairs = [[i,random_state.choice(self.label_to_indices[self.test_labels[i].item()]),1]
                              for i in range(0,len(self.test_data),2)]

            negative_pairs = [[i,random_state.choice(self.label_to_indices[
                                                       np.random.choice(
                                                           list(self.labels_set - set([self.test_labels[i].item()]))
                                                       )
                                                   ]),0]
                              for i in range(1,2)]
            self.test_pairs = positive_pairs + negative_pairs

    def __getitem__(self,index):
        if self.train:
            target = np.random.randint(0,2)
            img1,label1 = self.train_data[index],self.train_labels[index].item()
            if target == 1:
                siamese_index = index
                while siamese_index == index:
                    siamese_index = np.random.choice(self.label_to_indices[label1])
            else:
                siamese_label = np.random.choice(list(self.labels_set - set([label1])))
                siamese_index = np.random.choice(self.label_to_indices[siamese_label])
            img2 = self.train_data[siamese_index]
        else:
            img1 = self.test_data[self.test_pairs[index][0]]
            img2 = self.test_data[self.test_pairs[index][1]]
            target = self.test_pairs[index][2]

        img1 = Image.fromarray(img1.numpy(),mode='L')
        img2 = Image.fromarray(img2.numpy(),mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
        return (img1,img2),target

    def __len__(self):
        return len(self.dataset)

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)