How to implement a conditional bijector in TensorFlow: mainly a question of how to use kwargs, since this gives me an error

Problem description

I'm trying to implement a conditional bijector. It doesn't matter if you don't know what that is; my code essentially looks like this:

import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import numpy as np
from math import log,exp
tfb = tfp.bijectors
import pickle as pk
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd
import os

class varFamBij2(tf.keras.models.Model):
    def __init__(self,*,output_dim,**kwargs): #** additional arguments for the super class
        super().__init__(**kwargs)
        self.output_dim = output_dim
        num_bijectors = 5
        bijectors=[]
        for i in range(num_bijectors):
            # params=2: MaskedAutoregressiveFlow splits the network output into shift and log_scale
            bijectors.append(tfb.MaskedAutoregressiveFlow(tfp.bijectors.AutoregressiveNetwork(2,event_shape=self.output_dim,hidden_units=[32,32],conditional=True,conditional_event_shape= 13)))
            bijectors.append(tfb.Permute(permutation=[1,0]))
        bijectors.append(tfb.MaskedAutoregressiveFlow(tfp.bijectors.AutoregressiveNetwork(2,event_shape=self.output_dim,hidden_units=[32,32],conditional=True,conditional_event_shape= 13)))

        #A bijector is formed by chaining together many layers of bijectors
        self.bijector = tfb.Chain(bijectors)


          
x1 = tf.ones([2])
x2 = tf.ones([13])
mod11 = varFamBij2(output_dim=2)
predictions = mod11.bijector.forward(x1,conditional_input = x2)  # raises the ValueError below

Passing conditional_input = x2 is the culprit. Essentially, I get this error:

/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/bijectors/masked_autoregressive.py in call(self,x,conditional_input)
   1049       if self._conditional:
   1050         if conditional_input is None:
-> 1051           raise ValueError('`conditional_input` must be passed as a named '
   1052                            'argument')
   1053         conditional_input = tf.convert_to_tensor(

ValueError: 'conditional_input' must be passed as a named argument

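The problem is that the method call(self, x, conditional_input) takes this conditional_input argument, which according to the TF docs is supposed to be fed in through **kwargs (at least as far as I can tell with my poor understanding of kwargs). I don't think the **kwargs are actually being passed on as conditional_input: conditional_input defaults to None, and I believe that is what triggers the error.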

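I don't think detailed TensorFlow knowledge is needed to answer this; I suspect my limited understanding of kwargs is why the program doesn't work. Can anyone suggest a way of using kwargs (or some other approach) so that the call method receives my conditional_input? Thanks,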

Cameron
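The failure can be reproduced without TensorFlow. The sketch below uses two hypothetical classes, ToyShift and ToyChain (they are not TFP code), and assumes that tfb.Chain routes keyword arguments to its members only through dicts keyed by each member's name, as in the solution below. A top-level conditional_input= matches no member name, so every member's conditional_input stays at its default of None and the check raises.

```python
# Hypothetical sketch: a toy "chain" that routes keyword arguments to its
# members by name, mimicking (not reproducing) how tfb.Chain treats **kwargs.


class ToyShift:
    """Hypothetical conditional bijector: y = x + conditional_input."""

    def __init__(self, name):
        self.name = name

    def forward(self, x, conditional_input=None):
        # Same failure mode as AutoregressiveNetwork.call: the argument
        # defaults to None, so a kwarg that never arrives is caught here.
        if conditional_input is None:
            raise ValueError("`conditional_input` must be passed as a named argument")
        return x + conditional_input


class ToyChain:
    """Hypothetical chain: applies members right-to-left, like tfb.Chain."""

    def __init__(self, bijectors):
        self.bijectors = bijectors

    def forward(self, x, **kwargs):
        # kwargs is expected to look like {member_name: {arg_name: value}};
        # a top-level key that matches no member name is never looked up.
        for bij in reversed(self.bijectors):
            x = bij.forward(x, **kwargs.get(bij.name, {}))
        return x


chain = ToyChain([ToyShift(name="b1"), ToyShift(name="b2")])

# Fails: "conditional_input" is not the name of any member, so each member
# receives no kwargs and sees conditional_input=None.
try:
    chain.forward(1.0, conditional_input=10.0)
except ValueError as err:
    print("error:", err)

# Works: each member gets its own dict of keyword arguments, keyed by name.
print(chain.forward(1.0, b1={"conditional_input": 10.0}, b2={"conditional_input": 100.0}))  # → 111.0
```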

Solution

See SiegeLordEx's answer on the TensorFlow Probability GitHub issue tracker, which solved it: https://github.com/tensorflow/probability/issues/1159

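Essentially, Chain does not pass extra keyword arguments straight through to its members. For conditional bijectors, each member has to be given a name, and its keyword arguments are passed to the chain as a dict under that name: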

chain = Chain([Bijector1(name='b1'),Bijector2(name='b2')]) 
y = chain.forward(x,b1={'arg': 1},b2={'arg': 2})