连体网络为每对图像分配相同的标签

问题描述

我正在使用暹罗网络来识别输入图像是否相同。 问题在于,网络每次都会分配标签0(而不是0和1),导致每次精度为0.5,但我不知道为什么。

这是模型:

def Model():

  input_dim = (200,200,1)
  img_a = Input(shape = input_dim)
  img_b = Input(shape = input_dim)

  base_net = build_base_network(input_dim)

  features_a = base_net(img_a)
  features_b = base_net(img_b)

  distance = Lambda(euclidean_distance,output_shape = eucl_dist_output_shape)([features_a,features_b])
  model = Model(inputs=[img_a,img_b],outputs=distance)

  return model

这是距离

def contrastive_loss(y_true,y_pred):
  margin = 1
  return K.mean(y_true * K.square(y_pred) + (1 - y_true) * K.square(K.maximum(margin - y_pred,0)))

#Optimizer
rms = RMSprop()

#Distance
def euclidean_distance(vects):
  x,y = vects
  return K.sqrt(K.sum(K.square(x - y),axis=1,keepdims=True))

def eucl_dist_output_shape(shapes):
    shape1,shape2 = shapes
    return (shape1[0],1)

训练和X / y形状

#Pair_equal --> every element is a tuple of numpy arrays representing same images
#Pair_diff --> every element is a tuple of numpy arrays representing different images
#y_equal --> for every Pair equal's element,contains 0
#y_diff --> for every Pair equal's element,contains 1

  if len(Pair_equal) > len(Pair_diff):
    Pair_equal = Pair_equal[0:len(Pair_diff)]
    y_equal = y_equal[0:len(y_diff)]

  elif len(Pair_equal) < len(Pair_diff):
    Pair_diff = Pair_diff[0:len(Pair_equal)]
    y_diff = y_diff[0:len(y_equal)]

  Pair_equal = np.array(Pair_equal)
  Pair_diff = np.array(Pair_diff)
  y_equal = np.array(y_equal)
  y_diff = np.array(y_diff)

  X = np.concatenate([Pair_equal,Pair_diff],axis=0)
  y = np.concatenate([y_equal,y_diff],axis=0)

  y = y.reshape(-1,1)

  #index shuffling
  indices = np.arange(X.shape[0])
  np.random.shuffle(indices)

  X = X[indices]
  y = y[indices]
 
  #X shape: (32,2,1)
  #y shape: (32,1)

  return [X[:,...],X[:,1,...]],y

预先感谢您的帮助。

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...