Stellargraph 无法处理数据洗牌

问题描述

当我使用 DGCNN 在图分类上运行 stellargraphdemo 时,我得到了与演示中相同的结果。

但是,当我使用以下代码测试第一次打乱数据时会发生什么时:

shuffler = list(zip(graphs,graph_labels))
random.shuffle(shuffler)
graphs,graph_labels = zip(*shuffler)

该模型根本没有学习(准确度约为 50% - 就像数据分布一样)。

有谁知道为什么会这样?也许我洗牌的方式不对?或者是首先应该对数据进行打乱(也是为什么?这没有任何意义)?还是 stellargraph 实现中的错误

解决方法

我发现了问题。这与改组算法无关,也与 StellarGraph 的实现无关。问题出在演示中,在以下几行:

train_gen = gen.flow(
    list(train_graphs.index - 1),targets=train_graphs.values,batch_size=50,symmetric_normalization=False,)

test_gen = gen.flow(
    list(test_graphs.index - 1),targets=test_graphs.values,batch_size=1,)

问题是由 train_graphs.index - 1test_graphs.index - 1 引起的。索引已经在 0n 之间的范围内,因此从它们中减去一个会导致图形数据向后“移动”一个,从而导致每个数据点获得不同数据点的标签.

要解决此问题,只需将它们更改为 train_graphs.indextest_graphs.index,末尾不带 -1