问题描述
我曾经在PyTorch工作,但现在必须学习Tensorflow才能完成工作。我试图通过创建一个简单的密集网络并在MNIST数据集上对其进行训练来加快速度,但是我无法对其进行训练。我的超级简单代码:
import tensorflow as tf
from tensorflow.keras import Input
from tensorflow.keras.layers import Dense,Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.utils import to_categorical
# Load mnist data from keras
(train_data,train_label),(test_data,test_label) = tf.keras.datasets.mnist.load_data(path="mnist.npz")
train_label,test_label = to_categorical(train_label),to_categorical(test_label)
train_data,train_label,test_data,test_label = Flatten()(train_data),Flatten()(train_label),Flatten()(test_data),Flatten()(test_label)
# Create generic SGD optimizer (no learning schedule)
optimizer = SGD(learning_rate = 0.01)
# Define function to build and compile model
def build_mnist_model(input_shape,batch_size = 30):
input_img = Input(shape = input_shape,batch_size = batch_size)
# Pass through dense layer
x = Dense(200,activation = 'relu',use_bias = True)(input_img)
x = Dense(400,use_bias = True)(x)
scores = Dense(10,activation = 'softmax',use_bias = True)(x)
# Create and compile tf model
mnist_model = Model(input_img,scores)
mnist_model.compile(optimizer = optimizer,loss = 'categorical_crossentropy')
return mnist_model
# Build the model
mnist_model = build_mnist_model(train_data[0].shape)
# Train the model
mnist_model.fit(
x = train_data,y = train_label,batch_size = 30,epochs = 20,verbose = 2,shuffle = True,# steps_per_epoch = 200
)
运行此命令我会得到
ValueError: When using data tensors as input to a model,you should specify the `steps_per_epoch` argument.
这对我来说真的没有意义,因为我的train_data
和train_label
只是常规张量,在这种情况下,根据Tensorflow文档,它应该默认为数据集中的样本数除以批量(在我的情况下为200)。
无论如何,我在致电steps_per_epoch = 200
时尝试指定mnist_model.fit()
,但随后出现另一个错误:
InvalidArgumentError: Incompatible shapes: [60000,10] vs. [30,1]
[[{{node training_4/SGD/gradients/gradients/loss_5/dense_17_loss/softmax_cross_entropy_with_logits_grad/mul}}]]
我似乎无法识别尺寸不匹配的来源。在PyTorch中,我习惯于手动创建批处理(通过对数据和标签张量进行子索引),但是在Tensorflow中,这似乎是自动发生的。因此,这使我对什么批次的尺寸错误,如何获得错误的尺寸等感到困惑。我希望这个简单的模型比我制作起来容易得多,而且我还不知道Tensorflow的窍门。 / p>
感谢您的帮助。
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)