Problem description
How do I set up a multivariate regression problem using Trax?
With the code below I get AssertionError: Invalid shape (16,2); expected (16,). from the L2Loss object.
Here is my attempt to adapt the sentiment analysis example to a regression problem:
import os
import trax
from trax import layers as tl
from trax.supervised import training
import numpy
import random
#train_stream = trax.data.TFDS('imdb_reviews',keys=('text','label'),train=True)()
#eval_stream = trax.data.TFDS('imdb_reviews',train=False)()
def generate_samples():
    # (text, lat/lon)
    data = [
        ("Aberdeen MS", numpy.array((33.824742, -88.554591))),
        ("Aberdeen SD", numpy.array((45.463186, -98.471033))),
        ("Aberdeen WA", numpy.array((46.976432, -123.795781))),
        ("Amite City LA", numpy.array((30.733723, -90.5208))),
        ("Amory MS", numpy.array((33.984789, -88.48001))),
        ("Amouli AS", numpy.array((-14.26556, -170.589772))),
        ("Amsterdam NY", numpy.array((42.953149, -74.19505))),
    ]
    for i in range(1024*8):
        yield random.choice(data)
train_stream = generate_samples()
eval_stream = generate_samples()
model = tl.Serial(
    tl.Embedding(vocab_size=8192, d_feature=256),
    tl.Mean(axis=1),   # Average on axis 1 (length of sentence).
    tl.Dense(2),       # Regress to lat/lon
    # tl.LogSoftmax()  # Produce log-probabilities.
)
# You can print model structure.
print(model)
print(next(train_stream)) # See one example.
data_pipeline = trax.data.Serial(
    trax.data.Tokenize(vocab_file='en_8k.subword', keys=[0]),
    trax.data.Shuffle(),
    # trax.data.FilterByLength(max_length=2048, length_keys=[0]),
    trax.data.BucketByLength(boundaries=[8, 128],
                             batch_sizes=[256, 64, 4]),
    trax.data.AddLossWeights()
)
train_batches_stream = data_pipeline(train_stream)
eval_batches_stream = data_pipeline(eval_stream)
example_batch = next(train_batches_stream)
print(f'shapes = {[x.shape for x in example_batch]}')  # Check the shapes.
# Training task.
train_task = training.TrainTask(
    labeled_data=train_batches_stream,
    # loss_layer=tl.CrossEntropyLoss(),
    loss_layer=tl.L2Loss(),
    optimizer=trax.optimizers.Adam(0.01),
    n_steps_per_checkpoint=500,
)
# Evaluation task.
eval_task = training.EvalTask(
    labeled_data=eval_batches_stream,
    metrics=[tl.L2Loss()],
    n_eval_batches=20,  # For less variance in eval numbers.
)
# Training loop saves checkpoints to output_dir.
output_dir = os.path.expanduser('~/output_dir/')
training_loop = training.Loop(model,
                              train_task,
                              eval_tasks=[eval_task],
                              output_dir=output_dir)
# Run 2000 steps (batches).
training_loop.run(2000)
Solution
The problem is probably in the generate_samples() generator: it only yields 1024*8 (= 8192) samples. If I replace the line
for i in range(1024*8):
with
while True:
it yields an unlimited number of samples, and your example runs on my machine.
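As a minimal sketch of that change (everything else from the question stays the same), the generator would look like this:

def generate_samples():
    # Same (text, lat/lon) pairs as in the question.
    data = [
        ("Aberdeen MS", numpy.array((33.824742, -88.554591))),
        ("Aberdeen SD", numpy.array((45.463186, -98.471033))),
        # ... remaining cities from the question ...
    ]
    # Yield forever so the data pipeline never runs out of batches.
    while True:
        yield random.choice(data)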
Since generate_samples() yields only 8192 samples, train_batches_stream produces only 32 batches of 256 samples each, so you can train for at most 32 steps. You are, however, requesting 2000 steps.
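As a rough sanity check (a sketch only; the exact count can vary with how the pipeline handles the last, partially filled bucket), you can count how many batches the original finite generator actually produces:

# 8192 samples at a batch size of 256 (the short city names fall into
# the first bucket) give roughly 8192 / 256 = 32 batches, far fewer
# than the 2000 training steps requested.
n_batches = sum(1 for _ in data_pipeline(generate_samples()))
print(n_batches)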