广播和连接参差不齐的张量

问题描述

我有一个尺寸为 [BATCH_SIZE,TIME_STEPS,EMbedDING_DIM] 的参差不齐的张量。我想用另一个形状为 [BATCH_SIZE,AUG_DIM] 的张量的数据来扩充最后一个轴。给定示例的每个时间步都增加了相同的值。

如果每个示例的张量都没有变化的 TIME_STEPS,我可以简单地用 tf.repeat 重新塑造第二个张量,然后使用 tf.concat

import tensorflow as tf


# create data
# shape: [BATCH_SIZE,EMbedDING_DIM]
emb = tf.constant([[[1,2,3],[4,5,6]],[[1,[0,0]]])
# shape: [BATCH_SIZE,1,AUG_DIM]
aug = tf.constant([[[8]],[[9]]])

# concat
aug = tf.repeat(aug,emb.shape[1],axis=1)
emb_aug = tf.concat([emb,aug],axis=-1)

这在 emb 参差不齐时不起作用,因为 emb.shape[1] 未知且因示例而异:

# rag and remove padding
emb = tf.RaggedTensor.from_tensor(emb,padding=(0,0))

# reshape for augmentation - this doesn't work
aug = tf.repeat(aug,axis=1)

ValueError: 尝试将具有不受支持类型 () 的值 (None) 转换为张量。

目标是创建一个看起来像这样的参差不齐的张量 emb_aug

<tf.RaggedTensor [[[1,3,8],6,8]],9]]]>

有什么想法吗?

解决方法

执行此操作的最简单方法是使用 tf.RaggedTensor.to_tensor() 使不规则张量成为常规张量,然后执行其余的解决方案。我假设你需要张量保持参差不齐。关键是在你的参差不齐的张量中找到每个batch的row_lengths,然后利用这些信息使你的增强张量参差不齐。

示例

import tensorflow as tf


# data
emb = tf.constant([[[1,2,3],[4,5,6]],[[1,[0,0]]])
aug = tf.constant([[[8]],[[9]]])

# make embeddings ragged for testing
emb_r = tf.RaggedTensor.from_tensor(emb,padding=(0,0))

print(emb_r.shape)
# (2,None,3)

这里我们将使用 row_lengthssequence_mask 的组合来创建一个新的参差不齐的张量。

# find the row lengths of the embeddings
rl = emb_r.row_lengths()

print(rl)
# tf.Tensor([2 1],shape=(2,),dtype=int64)

# find the biggest row length
max_rl = tf.math.reduce_max(rl)

print(max_rl)
# tf.Tensor(2,shape=(),dtype=int64)

# repeat the augmented data `max_rl` number of times
aug_t = tf.repeat(aug,repeats=max_rl,axis=1)

print(aug_t)
# tf.Tensor(
# [[[8]
#   [8]]
# 
#  [[9]
#   [9]]],1),dtype=int32)

# create a mask
msk = tf.sequence_mask(rl)

print(msk)
# tf.Tensor(
# [[ True  True]
#  [ True False]],2),dtype=bool)

从这里我们可以使用 tf.ragged.boolean_mask 使增强数据变得参差不齐

# make the augmented data a ragged tensor
aug_r = tf.ragged.boolean_mask(aug_t,msk)
print(aug_r)
# <tf.RaggedTensor [[[8],[8]],[[9]]]>

# concatenate!
output = tf.concat([emb_r,aug_r],2)
print(output)
# <tf.RaggedTensor [[[1,3,8],6,8]],9]]]>

您可以找到支持不规则张量的 tensorflow 方法列表 here