如何提高 BERT keras hub 层输入的秩 (ndim) 以进行学习排名

问题描述

我正在尝试使用 tensorflow hub 上可用的预训练 BERT 来实现一个学习排名模型。我正在使用 ListNet 损失函数的变体,它要求每个训练实例都是与查询相关的几个排名文档的列表。我需要模型能够接受形状(batch_size、list_size、sentence_length)中的数据,其中模型在每个训练实例中的“list_size”轴上循环,返回等级并将它们传递给损失函数。在仅由密集层组成的简单模型中,这可以通过增加输入层的维度轻松实现。例如:

from tensorflow.keras.layers import Dense,Input
from tensorflow.keras import Model

input = Input([6,10])
x = Dense(20,activation='relu')(input)
output = Dense(1,activation='sigmoid')(x)
model = Model(inputs=input,outputs=output)

...现在模型将在计算损失和更新梯度之前对长度为 10 的向量执行 6 次前向传递。

我正在尝试对 BERT 模型及其预处理层执行相同的操作:

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text

bert_preprocess_model = hub.KerasLayer('https://tfhub.dev/tensorflow/small_bert/bert_en_uncased_L-4_H-512_A-8/1')
bert_model = hub.KerasLayer('https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3')
    
text_input = tf.keras.layers.Input(shape=(),dtype=tf.string,name='text')
processed_input = bert_preprocess_model(text_input)
output = bert_model(processed_input)
model = tf.keras.Model(text_input,output)

但是当我尝试将 'text_input' 的形状更改为 (6) 或以任何方式对其进行干预时,它总是会导致相同类型的错误

 ValueError: Could not find matching function to call loaded from the SavedModel. Got:
      Positional arguments (3 total):
        * Tensor("inputs:0",shape=(None,6),dtype=string)
        * False
        * None
      Keyword arguments: {}
    
    Expected these arguments to match one of the following 4 option(s):
    
    Option 1:
      Positional arguments (3 total):
        * TensorSpec(shape=(None,),name='sentences')
        * False
        * None
      Keyword arguments: {}
     ....

根据https://www.tensorflow.org/hub/api_docs/python/hub/KerasLayer,您似乎可以通过 tf.keras.layers.InputSpec 配置 hub.KerasLayer 的输入形状。就我而言,我想应该是这样的:

bert_preprocess_model.input_spec = tf.keras.layers.InputSpec(ndim=2)
bert_model.input_spec = tf.keras.layers.InputSpec(ndim=2)

当我运行上面的代码时,属性确实发生了变化,但是在尝试构建模型时,出现了完全相同的错误

有没有什么方法可以轻松解决这个问题,而无需创建自定义训练循环?

解决方法

假设您有一批 B 个示例,每个示例都有 N 个文本字符串,这构成了形状为 [B,N] 的二维张量。使用 tf.reshape(),您可以将其转换为形状为 [B*N] 的一维张量,通过 BERT(保留输入顺序)发送,然后将其重新整形回 [B,N]。 (还有 tf.keras.layers.Reshape,但它对您隐藏了批次维度。)

如果不是每次都恰好是 N 个文本字符串,则您必须在旁边做一些记账(例如,将输入存储在 tf.RaggedTensor 中,在其 .values 上运行 BERT,并构造一个结果中具有相同 .row_splits 的新 RaggedTensor。)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...