Calculating gradients in a custom training loop: performance difference between TF and Torch

Problem description

I am trying to convert a PyTorch implementation of an NN model that computes the forces and energies of molecular structures to TensorFlow. This requires a custom training loop and a custom loss function, so I implemented the two different one-step training functions below.

  1. First, using nested gradient tapes:
def calc_gradients(D_train_batch, E_train_batch, F_train_batch, opt):
    #set up gradient tape scope in order to track gradients of both d(Loss)/d(Weights)
    #and d(output)/d(input)
    with tf.GradientTape() as tape1:
        with tf.GradientTape() as tape2:
            #set gradient tape to watch Tensor
            tape2.watch(D_train_batch)
            #pass D through model to get predicted energy vals
            E_pred = model(D_train_batch, training=True)

        df_dD_train_batch = tape2.gradient(E_pred, D_train_batch)
        #matrix mult of -Grad_D(f) x Grad_r(D)
        F_pred = -tf.einsum('ijkl,il->ijk', dD_dr_train_batch, df_dD_train_batch)
        #calculate loss value
        loss = force_energy_loss(E_pred, F_pred, F_train_batch)

    grads = tape1.gradient(loss, model.trainable_weights)
    opt.apply_gradients(zip(grads, model.trainable_weights))

  2. A further attempt using a single gradient tape with persistent=True:
def calc_gradients_persistent(D_train_batch, E_train_batch, F_train_batch, opt):
    #set up a single persistent gradient tape to track gradients of both
    #d(Loss)/d(Weights) and d(output)/d(input)
    with tf.GradientTape(persistent=True) as outer:
        #set gradient tape to watch the input Tensor
        outer.watch(D_train_batch)

        #output values from model; set training to be true so that
        #model.trainable_weights is populated
        E_pred = model(D_train_batch, training=True)

        #set gradient tape to watch trainable weights
        outer.watch(model.trainable_weights)

        #get gradient of output (f/E_pred) w.r.t. input (D/D_train_batch)
        df_dD_train_batch = outer.gradient(E_pred, D_train_batch)

        #matrix mult of -Grad_D(f) x Grad_r(D)
        F_pred = -tf.einsum('ijkl,il->ijk', dD_dr_train_batch, df_dD_train_batch)

        #calculate loss value
        loss = force_energy_loss(E_pred, F_pred, F_train_batch)

    #get gradient of loss w.r.t. trainable weights for backpropagation
    grads = outer.gradient(loss, model.trainable_weights)

    #update weights using the optimizer and the gradients (grads)
    opt.apply_gradients(zip(grads, model.trainable_weights))

    #a persistent tape holds resources until it is garbage collected, so release it explicitly
    del outer
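
A note on the 'ijkl,il->ijk' contraction used in both versions: assuming, as in the DScribe tutorial linked below, that dD_dr_train_batch holds the descriptor derivatives with shape (n_samples, n_atoms, 3, n_features) and df_dD_train_batch has shape (n_samples, n_features), the contraction yields forces of shape (n_samples, n_atoms, 3). A standalone shape check with made-up sizes:

import tensorflow as tf

#made-up sizes: 4 samples, 5 atoms, 3 Cartesian components, 8 descriptor features
dD_dr = tf.random.normal((4, 5, 3, 8))
df_dD = tf.random.normal((4, 8))

F = -tf.einsum('ijkl,il->ijk', dD_dr, df_dD)
print(F.shape)  # (4, 5, 3)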

Both are attempted translations of this PyTorch code:

# Forward pass: Predict energies from the descriptor input
E_train_pred_batch = model(D_train_batch)

# Get derivatives of model output with respect to input variables. The
# torch.autograd.grad function can be used for this, as it returns the
# gradients of the outputs with respect to the inputs. It is very important
# to set create_graph=True in this case. Without it the derivatives of the
# NN parameters with respect to the loss from the force error will not be
# populated (=the force error will not affect the training), but the model
# will still run fine without errors.
df_dD_train_batch = torch.autograd.grad(
    outputs=E_train_pred_batch,
    inputs=D_train_batch,
    grad_outputs=torch.ones_like(E_train_pred_batch),
    create_graph=True,
)[0]

# Get derivatives of input variables (=descriptor) with respect to atom
# positions = forces
F_train_pred_batch = -torch.einsum('ijkl,il->ijk', dD_dr_train_batch, df_dD_train_batch)

# Zero gradients, perform a backward pass, and update the weights.
# D_train_batch.grad.data.zero_()
optimizer.zero_grad()
loss = energy_force_loss(E_train_pred_batch, F_train_pred_batch, F_train_batch)
loss.backward()
optimizer.step()

This comes from the DScribe library tutorial at https://singroup.github.io/dscribe/latest/tutorials/machine_learning/forces_and_energies.html

Question

Compared to running the PyTorch version, either TF implementation suffers a huge loss in prediction accuracy. I am wondering whether I have misunderstood the PyTorch code and translated it incorrectly, and if so, where do my versions diverge?

P.S. The model directly computes the energies E, from which we use the gradient of E w.r.t. D to compute the forces F. The loss function is a weighted sum of the MSEs of the forces and of the energies.
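
The body of the loss function never appears in the post; purely as illustration, here is a minimal sketch of such a weighted sum of MSEs. The signature, the explicit E_true argument, and the weight hyperparameters are assumptions; the post's three-argument call evidently obtains the reference energies some other way.

import tensorflow as tf

#hypothetical weighted-MSE loss; signature and weights are assumed, not from the post
def force_energy_loss(E_pred, E_true, F_pred, F_true, energy_weight=1.0, force_weight=1.0):
    E_loss = tf.reduce_mean(tf.square(E_pred - E_true))  #energy MSE
    F_loss = tf.reduce_mean(tf.square(F_pred - F_true))  #force MSE
    return energy_weight * E_loss + force_weight * F_loss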

Solution

The approaches are in fact equivalent; my error was elsewhere and was what produced the differing results. In both TF versions the input gradient df_dD_train_batch is computed while the tape that tracks the trainable weights is still recording, which plays the same role as create_graph=True in torch.autograd.grad, so the force error does propagate into the weight updates. For anyone attempting a TensorFlow implementation: the nested gradient tapes are roughly 2x faster, at least in this scenario, and make sure to wrap the training-step function in @tf.function so it runs as a graph rather than eagerly, which gives a speedup of roughly 10x.
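
A minimal sketch of that wrapping, assuming the nested-tape calc_gradients above; n_epochs and train_dataset are placeholder names for the epoch count and the batched dataset:

#trace the eager step function into a graph; the first call compiles it,
#subsequent calls with matching input signatures reuse the compiled graph
train_step = tf.function(calc_gradients)

for epoch in range(n_epochs):
    for D_batch, E_batch, F_batch in train_dataset:
        train_step(D_batch, E_batch, F_batch, opt)

Equivalently, decorate the function definition itself with @tf.function.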
