Problem description
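I have code for a hybrid model: one branch is built on EfficientNet, and the rest is trained on some external data I put together. Here is the model definition: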
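from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, Flatten, concatenate
from tensorflow.keras.applications import EfficientNetB5

def create_model():
    # Define parameters
    inputShape = (256, 256, 3)  # channels-last RGB images
    inputDim = 8                # number of metadata features
    # define MLP network for the metadata features
    model = Sequential()
    model.add(Dense(8, input_dim=inputDim, activation="relu"))
    model.add(Dense(4, activation="relu"))
    # define the CNN branch on top of EfficientNetB5
    cnnModel = Sequential()
    cnnModel.add(EfficientNetB5(include_top=False, input_shape=inputShape))
    cnnModel.add(Flatten())
    cnnModel.add(Dense(units=16, activation='relu'))
    cnnModel.add(Dense(units=4, activation='relu'))
    # Concatenate the two branches and add the classification head
    fullModel = concatenate([cnnModel.output, model.output])
    fullModel = Dense(4, activation="relu")(fullModel)
    fullModel = Dense(1, activation="sigmoid")(fullModel)
    model = Model(inputs=[cnnModel.input, model.input], outputs=fullModel)
    return model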
However, when I train it with fit_generator as follows, I get the error below:
from tensorflow.keras.optimizers import Adam

batch_size = 16
train_steps = TrainData.shape[0] // batch_size
valid_steps = TrainData.shape[0] // batch_size
model = create_model()
opt = Adam(lr=1e-3, decay=1e-3 / 200)
model.compile(loss="binary_crossentropy", optimizer=opt)
print("[INFO] training model...")
# fit_generator is deprecated since TF 2.1; model.fit also accepts generators
model.fit_generator(
    train_dl, epochs=3, steps_per_epoch=train_steps
)
model.save("models/final_model")
InvalidArgumentError: Incompatible shapes: [16,3,256] vs. [1,1,3]
[[node model_47/efficientnetb5/normalization_52/sub (defined at <ipython-input-262-76be6a4af4a4>:11) ]] [Op:__inference_train_function_1072272]
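The op that fails is `efficientnetb5/normalization_52/sub`, the subtraction inside the Normalization layer at the start of EfficientNetB5, and its second operand has shape [1,1,3]: one value per colour channel on the last axis. Broadcasting (TensorFlow follows the NumPy rules) aligns shapes starting from the last axis, so the input's last axis must be 3; in the reported shape [16,3,256] it is 256, with the 3 sitting earlier. One plausible reading, not confirmed by the post, is that the images reach the model channels-first, (3, 256, 256), instead of channels-last, (256, 256, 3). A NumPy sketch of that broadcasting rule, using placeholder statistics rather than the values EfficientNet actually uses:

```python
import numpy as np

# Placeholder per-channel statistics shaped like the [1, 1, 3] operand in the
# error: one mean and one variance per RGB channel, on the LAST axis.
channel_mean = np.array([0.5, 0.5, 0.5]).reshape(1, 1, 3)
channel_var = np.array([0.25, 0.25, 0.25]).reshape(1, 1, 3)

def normalize(images: np.ndarray) -> np.ndarray:
    """(x - mean) / sqrt(var), broadcast over the trailing channel axis."""
    return (images - channel_mean) / np.sqrt(channel_var)

channels_last = np.zeros((16, 256, 256, 3), dtype="float32")   # layout the model expects
channels_first = np.zeros((16, 3, 256, 256), dtype="float32")  # CHW layout

print(normalize(channels_last).shape)  # (16, 256, 256, 3)

try:
    normalize(channels_first)  # trailing axes 256 vs 3 cannot broadcast
except ValueError as err:
    print("broadcast error:", err)

# Moving the channel axis to the end makes the shapes compatible again.
fixed = np.transpose(channels_first, (0, 2, 3, 1))
print(normalize(fixed).shape)  # (16, 256, 256, 3)
```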
I'm not sure where this error comes from: the data loader or EfficientNet itself. Any ideas?
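To tell whether the loader or the network is at fault, one option is to pull a single batch from the generator and compare its shapes with what create_model declares: inputShape = (256,256,3) for the image branch and inputDim = 8 for the MLP branch. A minimal sketch; describe_batch is a hypothetical helper and assumes the generator yields ([images, data], labels), which is the structure data_generator yields:

```python
import numpy as np

def describe_batch(batch, expected_image_shape=(256, 256, 3), expected_data_dim=8):
    """Return a list of shape problems in one ([images, data], labels) batch."""
    (images, data), labels = batch
    problems = []
    # Image branch: (batch, H, W, C), channels-last
    if images.ndim != 4 or images.shape[1:] != expected_image_shape:
        problems.append(f"images: got {images.shape}, expected trailing shape {expected_image_shape}")
    # MLP branch: (batch, n_features)
    if data.ndim != 2 or data.shape[1] != expected_data_dim:
        problems.append(f"data: got {data.shape}, expected (batch, {expected_data_dim})")
    # All three arrays must agree on the batch size
    if not (len(images) == len(data) == len(labels)):
        problems.append(f"batch sizes differ: {len(images)}, {len(data)}, {len(labels)}")
    return problems

# Example with a fake channels-first batch and 3-D metadata
fake = ([np.zeros((16, 3, 256, 256)), np.zeros((16, 1, 8))], np.zeros(16))
for problem in describe_batch(fake):
    print(problem)
```

With the real loader, calling print(describe_batch(next(train_dl))) before training shows whether the shapes agree with the model's inputs; an empty list means they do.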
Edit: adding the data loader:
import numpy as np
from PIL import Image

def data_generator(image_dir, dataframe, min_max, binary, category, transforms=None, batch_size=16):
    i = 0
    samples_per_epoch = dataframe.shape[0]
    while True:
        batch = {'images': [], 'data': [], 'labels': []}  # use a dict for multiple inputs
        # Randomly sample images in dataframe
        idx = i
        img_path = f"{image_dir}/{dataframe.iloc[idx]['image_name']}.jpg"
        img = Image.open(img_path)
        if transforms:
            img = transforms(**{"image": np.array(img)})["image"]
        img = np.asarray(img, dtype="int32")
        # make data into tensors
        dataframe2 = dataframe.iloc[idx]
        data_cont = min_max.transform(np.array(dataframe2['age_approx']).reshape(1, -1))
        data_bina = binary.transform(dataframe2['sex'])
        data_cate = category.transform(dataframe2['anatom_site_general_challenge'])
        data_total = np.concatenate((data_cont, data_bina, data_cate), axis=1)
        label = dataframe2['target']
        batch['images'].append(img)
        batch['data'].append(data_total)
        batch['labels'].append(label)
        batch['images'] = np.array(batch['images'])  # convert each list to array
        batch['data'] = np.array(batch['data'])
        batch['labels'] = np.array(batch['labels'])
        i += 1
        # start again from the first row after the last one
        if i >= samples_per_epoch:
            i = 0
        yield [batch['images'], batch['data']], batch['labels']

def get_data(train_df, valid_df, train_tfms, test_tfms, batch_size, category):
    # min_max and binary are fitted encoders defined at module level
    train_dl = data_generator(image_dir='train/', dataframe=train_df, batch_size=batch_size,
                              min_max=min_max, binary=binary, category=category, transforms=train_tfms)
    valid_dl = data_generator(image_dir='train/', dataframe=valid_df, batch_size=batch_size * 2,
                              min_max=min_max, binary=binary, category=category, transforms=test_tfms)
    return train_dl, valid_dl
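As written, data_generator appends a single row and then yields, so every yielded batch holds one sample regardless of batch_size, and idx = i visits the rows in order. A minimal sketch of a generator that instead collects exactly batch_size samples per yield and reshuffles on each pass over the data; load_sample is a hypothetical callback standing in for the per-row code above (open the image, apply the transforms, build data_total, read target):

```python
import numpy as np

def batch_generator(load_sample, num_samples, batch_size=16, shuffle=True, seed=None):
    """Yield ([images, data], labels) batches of exactly `batch_size` samples.

    `load_sample(idx)` is a hypothetical callback returning
    (image array of shape (H, W, C), feature array, label) for row `idx`.
    """
    rng = np.random.default_rng(seed)
    order = np.arange(num_samples)
    pos = num_samples  # forces a shuffle before the first sample
    while True:
        images, data, labels = [], [], []
        while len(images) < batch_size:
            if pos >= num_samples:  # start a new pass over the rows
                if shuffle:
                    rng.shuffle(order)
                pos = 0
            img, feats, label = load_sample(order[pos])
            pos += 1
            images.append(img)
            data.append(np.ravel(feats))  # (1, n) row -> (n,) vector
            labels.append(label)
        yield [np.stack(images), np.stack(data)], np.array(labels)

# Example with a fake loader: 8x8 images and 8 features per row
def fake_sample(idx):
    return np.zeros((8, 8, 3)), np.ones((1, 8)), int(idx)

gen = batch_generator(fake_sample, num_samples=20, batch_size=16, seed=0)
(x_img, x_data), y = next(gen)
print(x_img.shape, x_data.shape, y.shape)  # (16, 8, 8, 3) (16, 8) (16,)
```

Flattening each (1, n) feature row means the data input stacks to (batch, n), which is the 2-D shape a Dense layer with input_dim=inputDim expects. The sketch does not change channel order; images returned by load_sample still need to be (256, 256, 3) for the EfficientNetB5 branch.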
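I seem to hit the same problem when I use only the images with EfficientNet. Keras's built-in image data loading seems to be the only way I can get it to work (with images only).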