问题描述
我正在尝试用 Python(Keras)重新实现 MATLAB 中提供的 SqueezeNet CNN,但在训练时出现了下面的错误。我的代码参考了 GitHub 上的这个项目,在此附上链接以注明出处: https://github.com/chasingbob/squeezenet-keras
```
InvalidArgumentError Traceback (most recent call last)
<ipython-input-21-faad1ec2e17b> in <module>
23 validation_data=val_data_gen,24 #nb_val_samples=21,---> 25 callbacks=[checkpoint])
in fit_generator(self,generator,steps_per_epoch,epochs,verbose,callbacks,validation_data,validation_steps,validation_freq,class_weight,max_queue_size,workers,use_multiprocessing,shuffle,initial_epoch)
1294 shuffle=shuffle,1295 initial_epoch=initial_epoch,->1296 steps_name='steps_per_epoch')
1297
1298 def evaluate_generator(self,in model_iteration(model,data,initial_epoch,mode,batch_size,steps_name,**kwargs)
263
264 is_deferred = not model._is_compiled
--> 265 batch_outs = batch_function(*batch_data)
266 if not isinstance(batch_outs,list):
267 batch_outs = [batch_outs]
in train_on_batch(self,x,y,sample_weight,reset_metrics)
1015 self._update_sample_weight_modes(sample_weights=sample_weights)
1016 self._make_train_function()
-> 1017 outputs = self.train_function(ins) # pylint: disable=not-
callable
1018
1019 if reset_metrics:
in __call__(self,inputs)
3471 Feed_symbols != self._Feed_symbols or self.fetches !=
self._fetches or
3472 session != self._session):
-> 3473 self._make_callable(Feed_arrays,Feed_symbols,symbol_vals,session)
3474
3475 fetched = self._callable_fn(*array_vals,in _make_callable(self,Feed_arrays,session)
3408 callable_opts.run_options.copyFrom(self.run_options)
3409 # Create callable.
-> 3410 callable_fn = session._make_callable_from_options(callable_opts)
3411 # Cache parameters corresponding to the generated callable,so that
3412 # we can detect future mismatches and refresh the callable.
in _make_callable_from_options(self,callable_options)
1503 """
1504 self._extend_graph()
-> 1505 return BaseSession._Callable(self,callable_options)
1506
1507
in __init__(self,session,callable_options)
1458 try:
1459 self._handle = tf_session.TF_SessionMakeCallable(
-> 1460 session._session,options_ptr)
1461 finally:
1462 tf_session.TF_DeleteBuffer(options_ptr)
InvalidArgumentError: Default MaxPoolingOp only supports NHWC on device type cpu
[[{{node maxpool4/MaxPool}}]]
```
这是代码
# Batch size, epoch count and image geometry used by the input pipeline.
batch_size = 10
epochs = 18
IMG_HEIGHT = 227
IMG_WIDTH = 227

# Generator for our training data; every pixel value is multiplied by 1/255.
train_image_generator = ImageDataGenerator(rescale=1./255)
# Generator for our validation data, rescaled the same way.
validation_image_generator = ImageDataGenerator(rescale=1./255)

# NOTE(review): shuffle=False is kept from the original code; confirm that
# feeding the training images in a fixed order is intended.
train_data_gen = train_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=train_dir,
    shuffle=False,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='categorical')

# target_size must be given here as well: without it Keras resizes images to
# its default of 256x256, which does not match the 227x227 network input.
val_data_gen = validation_image_generator.flow_from_directory(
    batch_size=batch_size,
    directory=validation_dir,
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    class_mode='categorical')
def _fire_module(x, index, squeeze_filters, expand_filters):
    """Build one Fire module: a squeeze conv feeding two parallel expand convs.

    Layer names follow the original code ('fire<index>_squeeze',
    'fire<index>_expand1', 'fire<index>_expand2').
    """
    # NOTE(review): activation / initializer / padding were lost from the
    # pasted source for the fire layers; these mirror the values still visible
    # on conv1. 'same' padding is required so both expand outputs keep the
    # same spatial size and can be concatenated.
    conv_args = dict(
        activation='relu',
        kernel_initializer='glorot_uniform',
        padding='same')
    prefix = 'fire%d' % index
    squeeze = Convolution2D(
        squeeze_filters, 1, name=prefix + '_squeeze', **conv_args)(x)
    # NOTE(review): kernel size 3 is what the source shows for fire2_expand1;
    # the size 1 used for expand2 is an assumption so that each module has one
    # 1x1 and one 3x3 expand branch.
    expand1 = Convolution2D(
        expand_filters, 3, name=prefix + '_expand1', **conv_args)(squeeze)
    expand2 = Convolution2D(
        expand_filters, 1, name=prefix + '_expand2', **conv_args)(squeeze)
    # Join along the last axis: with channels-last tensors that is the channel
    # axis (axis=1 would stack the feature maps along image height instead).
    return concatenate([expand1, expand2], axis=-1)


def SqueezeNet(nb_classes, inputs=(227, 227, 3)):
    """Keras implementation of SqueezeNet (arXiv 1602.07360).

    Every layer uses the default channels-last (NHWC) layout. The original
    code forced data_format='channels_first' on maxpool4 and maxpool8, which
    is what raised "Default MaxPoolingOp only supports NHWC on device type
    cpu" when training on a CPU.

    Arguments:
        nb_classes: total number of final categories.
        inputs: shape of one input image as (rows, cols, channels). The
            spatial size must be fixed (not None) because the final average
            pool is sized from it.

    Returns:
        A Model mapping an image batch to softmax class probabilities of
        shape (batch, nb_classes).
    """
    input_img = Input(shape=inputs)

    # The strides=(2, 2) argument was commented out in the original, so conv1
    # keeps the default stride of 1.
    conv1 = Convolution2D(
        96, 7,
        activation='relu',
        kernel_initializer='glorot_uniform',
        padding='same',
        name='conv1')(input_img)
    # NOTE(review): pool_size=(1, 1) with stride 2 is kept from the original;
    # it only subsamples. The paper uses 3x3 windows - confirm the intent.
    maxpool1 = MaxPooling2D(
        pool_size=(1, 1), strides=(2, 2), name='maxpool1')(conv1)

    merge2 = _fire_module(maxpool1, 2, 16, 64)
    merge3 = _fire_module(merge2, 3, 16, 64)
    merge4 = _fire_module(merge3, 4, 32, 128)
    maxpool4 = MaxPooling2D(
        pool_size=(1, 1), strides=(2, 2), name='maxpool4')(merge4)

    merge5 = _fire_module(maxpool4, 5, 32, 128)
    merge6 = _fire_module(merge5, 6, 48, 192)
    merge7 = _fire_module(merge6, 7, 48, 192)
    merge8 = _fire_module(merge7, 8, 64, 256)
    maxpool8 = MaxPooling2D(
        pool_size=(1, 1), strides=(2, 2), name='maxpool8')(merge8)

    merge9 = _fire_module(maxpool8, 9, 64, 256)
    fire9_dropout = Dropout(0.6, name='fire9_dropout')(merge9)

    # NOTE(review): kernel size, activation and initializer were lost from the
    # pasted source for conv10; these follow the other conv layers.
    conv10 = Convolution2D(
        nb_classes, 1,
        activation='relu',
        kernel_initializer='glorot_uniform',
        padding='valid',
        name='conv10')(fire9_dropout)

    # The pool window must match the spatial size of conv10's output so that
    # each class map is reduced to a single value; derive it from the tensor
    # instead of hard-coding it.
    pool_window = tuple(int(dim) for dim in conv10.shape[1:3])
    avgpool10 = AveragePooling2D(pool_window, name='avgpool10')(conv10)

    flatten = Flatten(name='flatten')(avgpool10)
    softmax = Activation("softmax", name='softmax')(flatten)
    return Model(inputs=input_img, outputs=softmax)
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.callbacks import EarlyStopping,ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Thread-pool sizes have to be configured before TensorFlow initialises its
# runtime, so set them before the model is built.
tf.config.threading.set_intra_op_parallelism_threads(2)
tf.config.threading.set_inter_op_parallelism_threads(2)

# Two output classes, 227x227 RGB input in channels-last order.
model = SqueezeNet(2, (227, 227, 3))
sgd = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)

# The generators use class_mode='categorical' (one-hot labels) and the model
# already ends in a softmax layer, so the loss must be the non-sparse
# categorical cross-entropy applied to probabilities, not logits.
model.compile(
    optimizer=sgd,
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=['accuracy'])

# NOTE(review): the three values below are not referenced by the calls that
# follow; they are kept in case other code reads them.
batch_size = 32
nb_classes = 10
nb_epoch = 200

# NOTE(review): early_stopping is created but never passed to the training
# call; add it to `callbacks` if early stopping is wanted.
early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=0)
# Save the weights only when the validation loss improves.
checkpoint = ModelCheckpoint(
    'weights.{epoch:02d}-{val_loss:.2f}.h5',
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
    save_weights_only=True,
    mode='min',
    period=1)

# NOTE(review): no `epochs` argument is passed, so Keras trains for its
# default of a single epoch; pass epochs=... to train longer.
model.fit_generator(
    train_data_gen,
    validation_data=val_data_gen,
    callbacks=[checkpoint])
解决方法
在此处为社区的利益发布答案。
将Tensorflow版本升级到 2.3 解决了该问题。
您可以使用以下行来升级Tensorflow版本。
pip install --user --upgrade tensorflow
此外,如果您运行的训练脚本支持命令行参数,也可以在启动脚本时传入以下选项,强制在 CPU 上使用 NHWC(channels_last)数据格式。
--device=cpu --data_format=NHWC