如何在TensorFlow中切片数据集?

问题描述

我想在tf.data中切片数据集。我的数据是这样的:

dataset = tf.data.Dataset.from_tensor_slices([[0,1,2,3,4],[1,4,5],[2,5,6],[3,6,7],[4,7,8]])

那么主要数据是:

[0 1 2 3 4]
[1 2 3 4 5]
[2 3 4 5 6]
[3 4 5 6 7]
[4 5 6 7 8]

我想创建其他包含如下数据的张量数据集:

       [[1,2],3],[5,6]]

在numpy中是这样的:

dataset[:,1:3]

如何在TensorFlow中做到这一点?

更新

我这样做是

dataset2 = dataset.map(lambda data: data[1:3])
for val in dataset2:
    print(val.numpy())

但是我认为有很好的解决方案。

解决方法

我认为您的解决方案是最好的解决方案。为了社区的利益,我正在使用tf.data.Dataset的{​​{3}}方法对数据集进行切片(对代码进行小的语法更改)。

请参考下面的代码

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([[0,1,2,3,4],[1,4,5],[2,5,6],[3,6,7],[4,7,8]])


dataset2 = dataset.map(lambda data: data[1:3])
for val in dataset2.as_numpy_iterator():
    print(val)

输出:

[1 2]
[2 3]
[3 4]
[4 5]
[5 6]

相关问答

依赖报错 idea导入项目后依赖报错,解决方案:https://blog....
错误1:代码生成器依赖和mybatis依赖冲突 启动项目时报错如下...
错误1:gradle项目控制台输出为乱码 # 解决方案:https://bl...
错误还原:在查询的过程中,传入的workType为0时,该条件不起...
报错如下,gcc版本太低 ^ server.c:5346:31: 错误:‘struct...