Tensorflow概率:从联合分布中检索特定随机变量

问题描述

我是张量流概率的新手。 我正在建立一个层次模型,为此我使用了 JointdistributionSequential API:

jds = tfp.distributions.JointdistributionSequential(
[
    # mu_g ~ uniform on sphere
    tfp.distributions.VonMisesFisher(
        mean_direction= [1] + [0]*(D-1),concentration=0,validate_args=True,name="mu_g"
    ),# epsilon ~ Exponential
    tfp.distributions.Exponential(
        rate=1,name="epsilon"
    ),# mu_s ~ von Mises Fisher centered on mu_g
    lambda epsilon,mu_g: tfp.distributions.VonMisesFisher(
        mean_direction=mu_g,concentration=np.array(
            [epsilon]*S
        ),name="mu_s"
    ),# sigma ~ Exponential
    tfp.distributions.Exponential(
        rate=1,name="sigma"
    ),# mu_t_s ~ von Mises Fisher centered on mu_s
    lambda sigma,mu_s: tfp.distributions.VonMisesFisher(
        mean_direction=mu_s,concentration=np.array(
            [
                [sigma]*S
            ]*T
        ),name="mu_t_s"
    ),# kappa ~ Exponential
    tfp.distributions.Exponential(
        rate=1,name="kappa"
    ),# x_t_s ~ mixture of L groups of vMF
    lambda kappa,mu_t_s: tfp.distributions.VonMisesFisher(
        mean_direction=mu_t_s,concentration=np.array(
            [
                [
                    [
                        kappa
                    ]*S
                ]*T
            ]*N
        ),name="x_t_s
    )            
]
)

然后我打算使用 Mixture API创建这些模型的混合体:

l = tfp.distributions.Categorical(
probs=np.array(
    [
        [
            [
                [1.0/L]*L
            ]*S
        ]*T 
    ]*N               
),name="l"
)

mixture = tfd.Mixture(
cat=l,components=[
    jds
] * L,validate_args=True
)

这不起作用。我打算混合使用的是批次模型(N,T,S)的分层模型“ x_t_s ”的“末端”随机变量。我想我需要将它们输入到 components 参数中进行混合。问题是我无法轻松地从 model 对象中检索这些变量。

有人看到解决这个问题的方法吗?

请注意,我尝试使用 jds.model [-1] 而不是 jds ,但这指向了lambda函数在这里我不需要

解决方法

这里有很多想法。

  1. 考虑SphericalUniform进行第一次分发。
  2. 对于同一类型的Mixture,请考虑使用MixtureSameFamily
  3. 将混合物放入分层模型中。即不是最后一个分发是vMF,而是可能是MixtureSameFamily(Categorical(...),VonMisesFisher(...))
  4. 如果以后要访问组件,可以调用ds,xs = jds.sample_distributions(),然后查看ds[-1].component_distribution

随时通过电子邮件发送[email protected]并附带问题。