如何重塑 Tensorflow 数据集中的数据?

问题描述

我正在编写一个数据管道,将成批的时间序列序列和相应的标签输入到需要 3D 输入形状的 LSTM 模型中。我目前有以下几点:

def split(window):
    return window[:-label_length],window[-label_length]

dataset = tf.data.Dataset.from_tensor_slices(data.sin)
dataset = dataset.window(input_length + label_length,shift=label_shift,stride=1,drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(input_length + label_length))
dataset = dataset.map(split,num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.cache()
dataset = dataset.shuffle(shuffle_buffer,seed=shuffle_seed,reshuffle_each_iteration=False)
dataset = dataset.batch(batch_size=batch_size,drop_remainder=True)
dataset = dataset.prefetch(tf.data.AUTOTUNE)

for x,y in dataset.take(1): x.shape 的结果形状是 (32,20),其中 32 是批量大小,20 是序列的长度,但我需要 (32,20,1) 的形状,其中额外的维度表示特征。

我的问题是如何重塑,理想情况下是在缓存数据之前传递给 split 函数dataset.map 函数中?

解决方法

这很简单。在您的拆分功能中执行此操作

def split(window):
    return window[:-label_length,tf.newaxis],window[-label_length,tf.newaxis,tf.newaxis]