Problem description
I am updating a PyTorch network from legacy code to the current API, following documentation such as the one here.
I used to have:
import random

import torch
from torchtext import data
from torchtext import datasets

# setting the seed so our random output is actually deterministic
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

# defining our input fields (text) and labels.
# We use the spaCy tokenizer because it provides strong support
# for tokenization in languages other than English
TEXT = data.Field(tokenize='spacy', include_lengths=True)
LABEL = data.LabelField(dtype=torch.float)

train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
train_data, valid_data = train_data.split(random_state=random.seed(SEED))

example = next(iter(test_data))
example.text

MAX_VOCAB_SIZE = 25_000
# unk_init controls how to initialize unseen words not in GloVe
TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE,
                 vectors="glove.6B.100d", unk_init=torch.Tensor.normal_)
LABEL.build_vocab(train_data)
Following the documentation, I have replaced the dataset loading with the current API:

from torchtext.datasets import IMDB
train_data, test_data = IMDB(split=('train', 'test'))
I can print the outputs, and although they look different (a problem for later?), they contain all the information. next(train_data) prints fine, and so does test_data. But after I then do:
test_size = int(len(train_data) / 2)
train_data, valid_data = torch.utils.data.random_split(train_data, [test_size, test_size])
it tells me:

next(train_data)
TypeError: 'Subset' object is not an iterator
This makes me think I am applying random_split incorrectly. How do I properly create a validation set for this dataset without running into this problem?
Solution
Try next(iter(train_data)). It seems the iterator over the dataset has to be created explicitly: random_split returns Subset objects, which are iterable but are not themselves iterators. And use a DataLoader when you actually need to consume the data for training and validation.
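A minimal sketch of the idea, using a small stand-in dataset (hypothetical, in place of the IMDB data) so it runs on its own: random_split yields Subsets, next() only works on an iterator obtained via iter(), and DataLoaders wrap the Subsets for batched iteration.

```python
import torch
from torch.utils.data import DataLoader, random_split

# Stand-in for the IMDB data: (label, text) pairs (hypothetical example data)
dataset = [(float(i % 2), f"review {i}") for i in range(10)]

half = len(dataset) // 2
train_data, valid_data = random_split(
    dataset, [half, len(dataset) - half],
    generator=torch.Generator().manual_seed(1234),  # reproducible split
)

# random_split returns Subset objects: indexable and iterable, not iterators.
example = next(iter(train_data))   # works: iter() creates the iterator
# next(train_data)                 # TypeError: 'Subset' object is not an iterator

# For the actual training/validation loops, wrap the Subsets in DataLoaders.
train_loader = DataLoader(train_data, batch_size=2, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=2)
```

The same pattern applies directly to the split IMDB data: keep the Subsets for indexing, and iterate through DataLoaders when batching.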