如何使用Keras点函数训练余弦相似度?

问题描述

在比较两个文档的嵌入时,我想在Keras模型中使用余弦相似度。

  • 培训目标应该在哪个值范围内?这意味着Dot函数的结果在哪个值范围内?如果嵌入非常相似,目标应该是什么值;如果嵌入非常不同,目标应该是什么?
  • 在这种情况下,轴参数是什么意思?设置axis = 1是否正确?

我的模型如下:

from tensorflow.python.keras.layers import Dot
from tensorflow.keras import Input,Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

input_document1 = Input(200)
model_1 = Sequential()
model_1.add(Dense(100,activation='relu'))
encoded_document1 = model_1(input_document1)

input_document2 = Input(200)
model_2 = Sequential()
model_2.add(Dense(100,activation='relu'))
encoded_document2 = model_2(input_document2)

distance_layer = Dot(axes=1,normalize=True) # cosine proximity
prediction = distance_layer([encoded_document1,encoded_document2])

siamese_net = Model(inputs=[input_document1,input_document2],outputs=prediction)

documentation中,他们说:

是否在获取点积之前沿点积轴L2归一化样本。如果设置为True,则点积的输出为两个样本之间的余弦近似度

解决方法

[-1,1]之间的输出值之间的余弦相似度,其中-1表示矢量完全相反,而1表示完美的重叠/相似性。

您可以在这里查看如何计算KerasComputing cosine similarity between two tensors in Keras

中的余弦相似度

关于axex /轴的解释,此处的文档做得很好:https://www.tensorflow.org/guide/tensor

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...