问题描述
我正在尝试实现条件双射器。如果您不知道那并不重要,但实际上我的代码是这样的:
import tensorflow as tf
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow_probability import distributions as tfd
import numpy as np
from math import log,exp
tfb = tfp.bijectors
import pickle as pk
import statsmodels.api as sm
import statsmodels.formula.api as smf
import pandas as pd
import os
class varFamBij2(tf.keras.models.Model):
def __init__(self,*,output_dim,**kwargs): #** additional arguments for the super class
super().__init__(**kwargs)
self.output_dim = output_dim
num_bijectors = 5
bijectors=[]
for i in range(num_bijectors):
bijectors.append(tfb.MaskedAutoregressiveFlow(tfp.bijectors.AutoregressiveNetwork(1,event_shape=self.output_dim,hidden_units=[32,32],conditional=True,conditional_event_shape= 13)))
bijectors.append(tfb.Permute(permutation=[1,0]))
bijectors.append(tfb.MaskedAutoregressiveFlow(tfp.bijectors.AutoregressiveNetwork(1,conditional_event_shape= 13)))
#A bijector is formed by chaining together many layers of bijectors
self.bijector = tfb.Chain(bijectors)
x1 = tf.ones([2])
x2 = tf.ones([13])
mod11 = varFamBij2(output_dim=2)
predictions = mod11.bijector.forward(x1,conditional_input = x2)
conditional_input = x2
是个怪人。本质上,我会收到此错误:
/usr/local/lib/python3.6/dist-packages/tensorflow_probability/python/bijectors/masked_autoregressive.py in call(self,x,conditional_input)
1049 if self._conditional:
1050 if conditional_input is None:
-> 1051 raise ValueError('`conditional_input` must be passed as a named '
1052 'argument')
1053 conditional_input = tf.convert_to_tensor(
ValueError: 'conditional_input' must be passed as a named argument
问题是函数call(self,conditional_input)
具有此conditional_input,假定每个TF文档将其作为** kwargs送入(至少根据我对kwargs的理解很差),我认为** kwargs不会作为条件输入条件输入(因为conditional_input的默认值为None,我认为这是引发错误的原因。)
我认为不需要TensorFlow的超详细知识来回答这个问题。我认为我无法理解和使用kwargs是导致该程序无法正常工作的原因。好奇是否有人可以建议使用kwargs(或其他方法)的方式,以便调用方法将接受我的conditional_input。谢谢,
Cameron
解决方法
在TensorFlow Probability GitHub页面上查看SiegeLordEx的答案。这样就解决了:https://github.com/tensorflow/probability/issues/1159
从本质上讲,必须为条件双射器编写以下代码:
chain = Chain([Bijector1(name='b1'),Bijector2(name='b2')])
y = chain.forward(x,b1={'arg': 1},b2={'arg': 2})