导出Keras中每个时期的嵌入

问题描述

我试图逐个访问Keras中嵌入层的输出(n维矢量)。似乎没有为此的特定回调。我已经尝试过Tensorboard回调,因为它提供了一个选项来记录每个时期的嵌入,但是当我找到日志文件时,我无法读取它们。它们可能是只能出于可视化目的而由Tensorboard访问的文件。我需要将嵌入矢量保存为以后可以在外部keras上使用的格式,例如TSV文件。有办法吗?

非常感谢!

解决方法

好的,因此我在Nazmul Hasan急需的有关如何格式化每个时期更新名称的格式的帮助下,找到了解决方法。本质上,我创建了一个自定义回调:

import io

encoder = info.features['text'].encoder

class CustomCallback(keras.callbacks.Callback):
    def on_epoch_end(self,epoch,logs=None):
        out_v = io.open('vecs_{}.tsv'.format(epoch),'w',encoding='utf-8')
        vec = model.layers[0].get_weights()[0] # skip 0,it's padding.
        out_v.write('\t'.join([str(x) for x in vec]) + "\n")
        out_v.close()