查找与Pytorch GRU功能等效的TensorFlow

问题描述

我对如何在TensorFlow中重建以下Pytorch代码感到困惑。它使用输入大小x和隐藏大小h来创建GRU层

import torch
torch.nn.GRU(64,64*2,batch_first=True,return_state=True) 

本能地,我首先尝试了以下方法

import tensorflow as tf
tf.keras.layers.GRU(64,return_state=True)

但是,我意识到它并没有真正解决h或隐藏的大小。在这种情况下我该怎么办?

解决方法

在您的 tensorflow 示例中,隐藏大小为 64。要获得等效项,您应该使用

import tensorflow as tf
tf.keras.layers.GRU(64*2,return_state=True)

这是因为keras层不需要你指定你的输入大小(在这个例子中是64);这取决于您第一次构建或运行模型的时间。