加快 GPT2 上的推理时间 - 优化 tf.sess.run()

问题描述

我正在尝试优化 GPT2 上的推理时间。在 Google Colab 上调用脚本后生成样本的当前时间为 55 秒。我输入时间戳以尝试隔离瓶颈所在。 这是代码

 for _ in range(nsamples // batch_size):
            out = sess.run(output,Feed_dict={
                context: [context_tokens for _ in range(batch_size)]
            })[:,len(context_tokens):]
            for i in range(batch_size):
                generated += 1
                text = enc.decode(out[i])
                print("=" * 40 + " SAMPLE " + str(generated) + " " + "=" * 40)
                print(text)
        print("=" * 80)

线

out = sess.run(output,len(context_tokens):] 

是复杂性所在。有没有人有什么办法可以改进这段代码?非常感谢!

解决方法

batch_size 在 GPT2 中设置为 1,并且无法在不使进程崩溃的情况下更改它。所以“[context_tokens for _ in range(batch_size)]”的意思是“[context_tokens for _ in range(1)]”的意思是“[context_tokens]”,这不会大大提高速度,但可以安全地实现并查看代码更理智一点。真正的复杂之处在于您的 ram 中有一个 6 GB 的 bohemoth,您正在该会话中访问它。

实际上,您发送的令牌越少,这些令牌的处理越少,这部分执行的速度就越快。因为每个令牌都需要通过 GPT2 AI 发送。但结果是响应越不“智能”。

顺便说一下 // 是整数除法运算,所以 nsamples // batch_size = nsamples/1 = nsamples 大小。从我所看到的,当我在 print(nsamples) 中打印它的值时,nsamples 是 1。所以for循环是一个项目的另一个循环,这意味着循环可以被删除。

GPT2 只是 tensorflow 的一个实现。查找:如何在tensorflow中制作图形;如何为该图调用会话;如何使保护程序保存该会话中的变量以及如何使用保护程序恢复会话。您将了解检查点、元文件和其他使您的文件更有意义的实现。

tensorflow 模块位于 Lib、site-packages、tensorflow_core(至少在 AI Dungeon 2 Henk717 fork 中)。大多数处理发生在子目录 python/ops 和框架中。如果您的编码破坏了 tf 预期的钩子,您将看到这些弹出窗口。

如果这个问题与 AI Dungeon 中的实现有关,那么我能够实现的最好方法是递归调用 generator.generate,该调用由除 KeyboardInterrupt 之外的尝试退出: with a print(token,end = '',flush = True) 为每个令牌生成。通过这种方式,您可以在 AI 生成每个令牌时查看每个令牌,而不是等待 55 秒的 ping 声音。

此外,Cuda 警告需要单引号,而不是双引号, 导入操作系统 os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' 不是“3” 这将在导入 tensorflow 时取消 cuda 警告。

接下来,在 1.5 以上的 tensorflow 版本中,GPT2 的实现会弹出折旧。

关闭那些 tfv = tf.compat.v1 tfv.set_verbosity(tfv.logging.Error) 是你所需要的全部。您不需要导入警告。

即便如此,在 tf 初始化、样本初始生成和模块加载到 ram 之间的加载时间也很长。我在 model.shape_list(x) 中添加: 下面一行 打印(“_”,end ='',flush = True) 至少对于正在构建以将其本地化到机器的模块,您可以查看各种“进度条”。