问题描述
我正在尝试使用gpt-2生成文本。即使运行Tensorflow 2.0 code upgrade script,我仍然遇到兼容性错误。
我遵循的步骤:
-
克隆repo
-
从现在开始,按照DEVELOPERS.md中的指示进行操作
-
在/ src中的文件上运行upgrade script
-
在终端运行中:
sudo docker build --tag gpt-2 -f Dockerfile.gpu .
-
构建完成后,运行:
sudo docker run --runtime=nvidia -it gpt-2 bash
-
输入
python3 src/generate_unconditional_samples.py | tee /tmp/samples
-
获取此追溯:
Traceback (most recent call last): File "src/generate_unconditional_samples.py",line 9,in <module> import model,sample,encoder File "/gpt-2/src/model.py",line 4,in <module> from tensorboard.plugins.hparams.api import HParam ImportError: No module named 'tensorboard.plugins.hparams' root@f8bdde043f91:/gpt-2# python3 src/generate_unconditional_samples.py | tee /tmp/samples Traceback (most recent call last): File "src/generate_unconditional_samples.py",in <module> import model,in <module> from tensorboard.plugins.hparams.api import HParam ImportError: No module named 'tensorboard.plugins.hparams'```
HParams似乎已被弃用,并且Tensorflow 2.0中的新版本称为HParam。但是,参数不同。在model.py
中,参数的实例化如下:
def default_hparams():
return HParams(
n_vocab=0,n_ctx=1024,n_embd=768,n_head=12,n_layer=12,)
Tensorflow 2.0中似乎没有任何1:1转换。有谁知道如何使gpt-2与Tensorflow 2.0兼容?
我的GPU是NVIDIA 20xx。
谢谢。
解决方法
如果要查看我的1.x fork,请在此处进行编译: https://github.com/timschott/gpt-2
,我遇到了同样的问题,但是通过在文件夹中创建一个单独的hparams.py文件并使用以下内容填充了该问题:https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/hparam.py
然后在您的model.py中,可以添加以下代码并将其替换掉:
import tensorflow as tf
from tensorflow.contrib.training import HParams
与此:
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from hparams import HParams
然后,您将必须在“ tf”和模块之间添加“ compat.v1”(如果这就是它的名字)...例如,如果它是“ tf.Session”,则将其更改为“ tf.compat.v1” .Session”,或者如果是“ tf.placeholder”,请将其更改为“ tf.compat.v1.placeholder”,等等。
我在尝试降级到tensorflow-gpu 1.13之后仍然执行此操作,而gpt-2仍然无法正常工作。在3.6版Python中运行环境的情况也是如此。
P.S。这是我的第一个答案,不确定我是否正确设置了格式,但我也会继续学习。
,自动升级代码很少开箱即用。使用该存储库,您应该最多使用 Tensorflow=1.15
。
如果你真的想要 Tensorflow=2
,你可以看看这个 repo:https://github.com/akanyaani/gpt-2-tensorflow2.0
注意:您不会获得预训练的模型(对于普通用户来说,这可能是 gpt2 中最有趣的部分)。这意味着无法访问他们的 1554M 或 778M 型号。
我不知道将预训练模型从 1.15 自动升级到 2.3 或诸如此类的方法。