问题描述
我有一个尺寸为 [BATCH_SIZE,TIME_STEPS,EMbedDING_DIM]
的参差不齐的张量。我想用另一个形状为 [BATCH_SIZE,AUG_DIM]
的张量的数据来扩充最后一个轴。给定示例的每个时间步都增加了相同的值。
如果每个示例的张量都没有变化的 TIME_STEPS
,我可以简单地用 tf.repeat
重新塑造第二个张量,然后使用 tf.concat
:
import tensorflow as tf
# create data
# shape: [BATCH_SIZE,EMbedDING_DIM]
emb = tf.constant([[[1,2,3],[4,5,6]],[[1,[0,0]]])
# shape: [BATCH_SIZE,1,AUG_DIM]
aug = tf.constant([[[8]],[[9]]])
# concat
aug = tf.repeat(aug,emb.shape[1],axis=1)
emb_aug = tf.concat([emb,aug],axis=-1)
这在 emb
参差不齐时不起作用,因为 emb.shape[1]
未知且因示例而异:
# rag and remove padding
emb = tf.RaggedTensor.from_tensor(emb,padding=(0,0))
# reshape for augmentation - this doesn't work
aug = tf.repeat(aug,axis=1)
ValueError: 尝试将具有不受支持类型 (
目标是创建一个看起来像这样的参差不齐的张量 emb_aug
:
<tf.RaggedTensor [[[1,3,8],6,8]],9]]]>
有什么想法吗?
解决方法
执行此操作的最简单方法是使用 tf.RaggedTensor.to_tensor()
使不规则张量成为常规张量,然后执行其余的解决方案。我假设你需要张量保持参差不齐。关键是在你的参差不齐的张量中找到每个batch的row_lengths
,然后利用这些信息使你的增强张量参差不齐。
示例:
import tensorflow as tf
# data
emb = tf.constant([[[1,2,3],[4,5,6]],[[1,[0,0]]])
aug = tf.constant([[[8]],[[9]]])
# make embeddings ragged for testing
emb_r = tf.RaggedTensor.from_tensor(emb,padding=(0,0))
print(emb_r.shape)
# (2,None,3)
这里我们将使用 row_lengths
和 sequence_mask
的组合来创建一个新的参差不齐的张量。
# find the row lengths of the embeddings
rl = emb_r.row_lengths()
print(rl)
# tf.Tensor([2 1],shape=(2,),dtype=int64)
# find the biggest row length
max_rl = tf.math.reduce_max(rl)
print(max_rl)
# tf.Tensor(2,shape=(),dtype=int64)
# repeat the augmented data `max_rl` number of times
aug_t = tf.repeat(aug,repeats=max_rl,axis=1)
print(aug_t)
# tf.Tensor(
# [[[8]
# [8]]
#
# [[9]
# [9]]],1),dtype=int32)
# create a mask
msk = tf.sequence_mask(rl)
print(msk)
# tf.Tensor(
# [[ True True]
# [ True False]],2),dtype=bool)
从这里我们可以使用 tf.ragged.boolean_mask
使增强数据变得参差不齐
# make the augmented data a ragged tensor
aug_r = tf.ragged.boolean_mask(aug_t,msk)
print(aug_r)
# <tf.RaggedTensor [[[8],[8]],[[9]]]>
# concatenate!
output = tf.concat([emb_r,aug_r],2)
print(output)
# <tf.RaggedTensor [[[1,3,8],6,8]],9]]]>
您可以找到支持不规则张量的 tensorflow 方法列表 here