“子集”对象不是用于更新火炬的旧 IMDB 数据集的迭代器

问题描述

我正在将 pytorch 网络从旧代码更新为当前代码。遵循诸如 here 之类的文档。

我曾经有过:

import torch
from torchtext import data
from torchtext import datasets

# setting the seed so our random output is actually deterministic
SEED = 1234

torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# defining our input fields (text) and labels. 
# We use the Spacy function because it provides strong support for tokenization in languages other than English
TEXT = data.Field(tokenize = 'spacy',include_lengths = True)
LABEL = data.LabelField(dtype = torch.float)

from torchtext import datasets
train_data,test_data = datasets.IMDB.splits(TEXT,LABEL)

import random
train_data,valid_data = train_data.split(random_state = random.seed(SEED))

example = next(iter(test_data))
example.text

MAX_VOCAB_SIZE = 25_000

TEXT.build_vocab(train_data,max_size = MAX_VOCAB_SIZE,vectors = "glove.6B.100d",unk_init = torch.Tensor.normal_) # how to initialize unseen words not in glove

LABEL.build_vocab(train_data)

现在在新代码中,我正在努力添加验证集。一切顺利,直到这里:

from torchtext.datasets import IMDB
train_data,test_data = IMDB(split=('train','test'))

我可以打印输出,虽然它们看起来不同(稍后有问题?),但它们具有所有信息。我可以用 next(train_data.

) 很好地打印 test_data

然后在我这样做之后:

test_size = int(len(train_dataset)/2)
train_data,valid_data = torch.utils.data.random_split(train_dataset,[test_size,test_size])

它告诉我:

一个(train_data)

TypeError: 'Subset' object is not an iterator

这让我觉得我在应用 random_split 时不正确。如何正确创建此数据集的验证集?不会引起问题。

解决方法

试试next(iter(train_data))。似乎必须明确地在 dataset 上创建迭代器。并在需要有效性时使用 Dataloader

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...