重塑归纳变量-GPflow

问题描述

我有一个SGPR模型:

import numpy as np
import gpflow

X,Y = np.random.randn(50,2),np.random.randn(50,1)
Z1 = np.random.randn(13,2)

k = gpflow.kernels.SquaredExponential()
m = gpflow.models.SGPR(data=(X,Y),kernel=k,inducing_variable=Z1)

我想分配归纳变量,但是形状不同,例如:

Z2 = np.random.randn(29,2)
m.inducing_variable.Z.assign(Z2)

但是,如果我这样做,我会得到:

ValueError: Shapes (13,2) and (29,2) are incompatible

是否可以在不重新定义模型的情况下重新分配归纳变量?

上下文:我不想使用归纳变量优化模型,而是想在不优化归纳变量的情况下优化模型,而是在优化的每个步骤中手动重新分配归纳变量。

解决方法

更新https://github.com/GPflow/GPflow/pull/1594解决了此问题,该问题将成为下一个GPflow修补程序版本(2.1.4)的一部分。

有了此修复程序,您不需要自定义类。您需要做的就是在第一个维度上使用None显式设置静态形状:

inducing_variable = gpflow.inducing_variables.InducingPoints(
    tf.Variable(
        Z1,# initial value
        trainable=False,# True does not work - see Note below
        shape=(None,Z1.shape[1]),# or even tf.TensorShape(None)
        dtype=gpflow.default_float(),# required due to tf's 32bit default
    )
)
m = gpflow.models.SGPR(data=(X,Y),kernel=k,inducing_variable=inducing_variable)

然后m.inducing_variable.Z.assign(Z2)应该可以正常工作。

注意,在这种情况下,Z 不可训练,因为TensorFlow优化器需要在构造时知道形状,并且不支持动态形状


目前(从GPflow 2.1.2开始),尽管在原则上是可行的,但是没有内置的方法来更改SGPR的归纳变量的形状。您可以通过自己的归纳变量类获得所需的内容:

class VariableInducingPoints(gpflow.inducing_variables.InducingPoints):
     def __init__(self,Z,name=None):
         super().__init__(Z,name=name)
         # overwrite with Variable with None as first element in shape so
         # we can assign arrays with arbitrary length along this dimension:
         self.Z = tf.Variable(Z,dtype=gpflow.default_float(),shape=(None,Z.shape[1])
         )

     def __len__(self):
         return tf.shape(self.Z)[0]  # dynamic shape
         # instead of the static shape returned by the InducingPoints parent class

然后做

m = gpflow.models.SGPR(
    data=(X,inducing_variable=VariableInducingPoints(Z1)
)

相反。然后您的m.inducing_variable.Z.assign()应该可以按照您的要求工作。

(对于SVGP,归纳变量的大小以及由q_muq_sqrt定义的分布必须匹配,并且在构造时必须知道,因此在这种情况下更改归纳变量的数量并非易事。)

相关问答

Selenium Web驱动程序和Java。元素在(x,y)点处不可单击。其...
Python-如何使用点“。” 访问字典成员?
Java 字符串是不可变的。到底是什么意思?
Java中的“ final”关键字如何工作?(我仍然可以修改对象。...
“loop:”在Java代码中。这是什么,为什么要编译?
java.lang.ClassNotFoundException:sun.jdbc.odbc.JdbcOdbc...