如何调整自动编码器中的超参数?

问题描述

我正在处理一个不受监督的问题,每当我尝试使用GridSeachCv时,都会在 grid.fit(X,X)

建议您添加代码。

X是形状为(4700,128,3)的输入数组,其中包含4700张图像的图像嵌入。请有人帮忙,我在这个问题上停留了太久了,什么都没想到。


def make_model(lr=0.0001,init= tf.keras.initializers.HeNormal()):
    global autoencoder,encoder
    print('MAKING MODEL')
    input_img=Input(shape=(X.shape[1:]))

    e1 = Conv2D(64,(3,3),activation='relu',padding='same',kernel_initializer=init)(input_img)
    e1 = MaxPooling2D((2,2),padding='same')(e1)

    e1 = Conv2D(32,kernel_initializer=init)(e1)
    e1 = MaxPooling2D((2,padding='same')(e1)
    
    e1 = Conv2D(16,kernel_initializer=init)(e1)
    enc = MaxPooling2D((2,padding='same')(e1)
 
    
    encoder=Model(input_img,enc)
    
    
    dec = Conv2D(16,kernel_initializer=init)(enc)
    d1 = UpSampling2D((2,2))(dec)

    

    d1 = Conv2D(32,kernel_initializer=init)(d1)
    d1 = UpSampling2D((2,2))(d1)

    
    d1 = Conv2D(64,kernel_initializer=init)(d1)
    
    
    d1 = Conv2D(3,kernel_initializer=init)(d1)
    decoded=UpSampling2D((2,2))(d1)

    autoencoder = Model(input_img,decoded)

    autoencoder.compile(optimizer=Adam(lr),loss='binary_crossentropy',metrics=['binary_crossentropy'])
    
    
    return autoencoder


def driver():
    print('PROGRAM START')
    from sklearn.model_selection import GridSearchCV
    from keras.wrappers.scikit_learn import KerasClassifier

    model=KerasClassifier(build_fn=make_model)

    param=dict(batch_size=[2,4,8,16,32])

    grid=GridSearchCV(estimator=model,param_grid=param)
    grid_result=grid.fit(X,X)
    
    print('PROGRAM END')

ValueError跟踪(最近一次通话最近) 在()中 3 4 img_size = 128 ----> 5驱动程序()

2幅 适合的/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/wrappers/scikit_learn.py(self,x,y,** kwargs) 第219章 第220章 -> 221提高ValueError('y的无效形状:'+ str(y.shape)) 222 self.n_classes_ = len(self.classes_) 223 return super(KerasClassifier,self).fit(x,y,** kwargs)

ValueError:y的形状无效:(4700、128、128、3)

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...