使用 RaggedTensor 输入的 TensorFlow(2.4.0) 模型训练问题

问题描述

我正在尝试训练一个 TF 模型,该模型将查看客户关于食品类别的购物车价值,并预测客户接下来会对哪个食品类别感兴趣。

只有 5 种食物类别:

>> all_food_categories
['pizza','side','dessert','drink','dip']

购物车中已有的食品类别是字符串类型的 RaggedTensor

for i in generator(): is # generator returns a tuple,(features,target)
print(i)
break

(

{'categories_of_assumed_cart_flattened': <tf.RaggedTensor [[b'<blank>'],[b'pizza',b'dessert'],b'side',b'dip',b'dessert',b'pizza'],[b'side'],b'pizza',[b'pizza'],b'side'],[b'side',b'drink',b'dip'],b'drink'],[b'drink'],[b'<blank>'],b'pizza']]>,'disposition_type': <tf.Tensor: shape=(32,),dtype=string,numpy=
    array([b'collection',b'delivery',b'collection',b'collection'],dtype=object)>},array([0,4,1,3,2,1]))

型号代码

class NextItemCategory(tf.keras.Model):
    def __init__(self,vocab,mask_token = '',embed_dim=4,conv_kernels=[3,5],max_seq_len = 7):
        super(NextItemCategory,self).__init__()
        self.mask_token = mask_token
        self.max_seq_len = max_seq_len
        self.lookup = tf.keras.layers.experimental.preprocessing.StringLookup(vocabulary=vocab,mask_token=mask_token)
        self.embed = tf.keras.layers.Embedding(len(self.lookup.get_vocabulary()),4)
        self.model_layers = [tf.keras.layers.Conv2D(filters=1,kernel_size=[ck_i,embed_dim],padding='same') for ck_i in conv_kernels]
        self.pool = tf.keras.layers.GlobalMaxPool2D()
        self.dense = tf.keras.layers.Dense(5,activation='softmax')


    def call(self,inputs):

        inp = inputs["categories_of_assumed_cart_flattened"]
        x = inp.to_tensor(default_value='',shape = (None,self.max_seq_len))
        x = self.lookup(x)
        x = self.embed(x)
        x = tf.expand_dims(input=x,axis=-1)
        z = []
        for layer in self.model_layers:
            y = layer(x)
            y = self.pool(y)
            z.append(y)
        z = tf.concat(z,axis=-1)
        z = self.dense(z)
        return z

cart_model = NextItemCategory(all_food_categories)
cart_model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer='adam',metrics=['accuracy'])

train_data_generator = generator()
cart_model.fit(train_data_generator,verbose=1)

模型正在构建中。

Model: "next_item_category"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
string_lookup_20 (StringLook multiple                  0         
_________________________________________________________________
embedding_20 (Embedding)     multiple                  28        
_________________________________________________________________
conv2d_60 (Conv2D)           multiple                  13        
_________________________________________________________________
conv2d_61 (Conv2D)           multiple                  17        
_________________________________________________________________
conv2d_62 (Conv2D)           multiple                  21        
_________________________________________________________________
global_max_pooling2d_20 (Glo multiple                  0         
_________________________________________________________________
dense_20 (Dense)             multiple                  20        
=================================================================
Total params: 99
Trainable params: 99
Non-trainable params: 0

但作为 fit() 的一部分,我收到此错误

AttributeError                            Traceback (most recent call last)
------
<ipython-input-110-5d3dc4a7c2fa> in call(self,inputs)
     12 
     13         inp = inputs["categories_of_assumed_cart_flattened"]
---> 14         x = inp.to_tensor(default_value='',self.max_seq_len))
     15         x = self.lookup(x)
     16         x = self.embed(x)

AttributeError: 'Tensor' object has no attribute 'to_tensor'

卷积开始之前,我已经尝试将 to_tensor() 调用放在不同的代码位置,但同样的错误仍然存​​在。

生成器的输出可以清楚地看出,inputs['categories_of_assumed_cart_flattened'] 始终属于 tf.RaggedTensor 类型。

我不知道问题是什么;非常感谢任何帮助!非常感谢!

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)