问题描述
当我使用 DGCNN 在图分类上运行 stellargraph 的 demo 时,我得到了与演示中相同的结果。
但是,当我使用以下代码测试第一次打乱数据时会发生什么时:
shuffler = list(zip(graphs,graph_labels))
random.shuffle(shuffler)
graphs,graph_labels = zip(*shuffler)
该模型根本没有学习(准确度约为 50% - 就像数据分布一样)。
有谁知道为什么会这样?也许我洗牌的方式不对?或者是首先应该对数据进行打乱(也是为什么?这没有任何意义)?还是 stellargraph 实现中的错误?
解决方法
我发现了问题。这与改组算法无关,也与 StellarGraph 的实现无关。问题出在演示中,在以下几行:
train_gen = gen.flow(
list(train_graphs.index - 1),targets=train_graphs.values,batch_size=50,symmetric_normalization=False,)
test_gen = gen.flow(
list(test_graphs.index - 1),targets=test_graphs.values,batch_size=1,)
问题是由 train_graphs.index - 1
和 test_graphs.index - 1
引起的。索引已经在 0
到 n
之间的范围内,因此从它们中减去一个会导致图形数据向后“移动”一个,从而导致每个数据点获得不同数据点的标签.
要解决此问题,只需将它们更改为 train_graphs.index
和 test_graphs.index
,末尾不带 -1
。