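Problem description
I am trying to implement semantic hashing for the MNISTAutoencoder example using DL4J. How can I binarize the middle-layer activations? Ideally, I am looking for a change to the network setup that yields binary middle-layer activations (almost) out of the box. Alternatively, I would be happy with a "recipe" for binarizing the current ReLU activations. Which of the two approaches is preferable in terms of generalization ability?
My current network setup is: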
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(12345)
        .weightInit(WeightInit.XAVIER)
        .updater(new AdaGrad(0.05))
        .activation(Activation.RELU)
        .l2(0.0001)
        .list()
        .layer(new DenseLayer.Builder().nIn(784).nOut(250)
                .build())
        .layer(new DenseLayer.Builder().nIn(250).nOut(10)
                .build())
        .layer(new DenseLayer.Builder().nIn(10).nOut(250)
                .build())
        .layer(new OutputLayer.Builder().nIn(250).nOut(784)
                .activation(Activation.LEAKYRELU)
                .lossFunction(LossFunctions.LossFunction.MSE)
                .build())
        .build();
After 30 epochs, typical middle-layer activations look like this:
[[ 11.3044,12.3678,7.3547,1.6518,1.0068,5.4340,2.1388,2.0708,2.5764]]
[[ 9.9051,12.5345,11.1941,4.7900,1.2935,7.9786,4.1915,3.1802,7.5659]]
[[ 6.4629,11.1013,10.8903,5.4528,0.8009,9.4881,3.6684,6.4524,7.2334]]
[[ 2.3953,0.2429,3.7125,4.1561,0.8607,11.2486,7.0178,2.8771,2.1996]]
[[ 0,1.6378,0.8993,0.3347,0.7708,3.7053,1.6704,2.1380]]
[[ 0,1.5158,0.7937,0.8190,4.7548,0.0655,1.4635,1.8173]]
[[ 6.8344,5.9989,10.1286,2.8528,1.1178,9.1865,10.3677,5.3564,4.3420]]
[[ 7.0942,7.0364,4.8538,0.5096,0.0442,8.4336,8.2783,5.6474,3.8944]]
[[ 3.6895,14.9696,6.5351,8.0446,12.7816,12.7445,7.8495,3.8600]]
Solution
This can be achieved by assigning a custom IActivation function to the middle layer. For example:
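// Imports for the top of the enclosing source file. Package paths assume a recent
// DL4J/ND4J release (1.0.0-M1 or later) and may differ in older versions.
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.TanhDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.conditions.GreaterThanOrEqual;
import org.nd4j.linalg.indexing.conditions.LessThan;

public static class ActivationBinary extends BaseActivationFunction {
    @Override
    public INDArray getActivation(INDArray in, boolean training) {
        // Forward pass (in place): negative values become -1, all others become +1.
        in.replaceWhere(Nd4j.ones(in.length()).muli(-1), new LessThan(0));
        in.replaceWhere(Nd4j.ones(in.length()), new GreaterThanOrEqual(0));
        return in;
    }

    @Override
    public Pair<INDArray, INDArray> backprop(INDArray in, INDArray epsilon) {
        this.assertShape(in, epsilon);
        // The step function has no useful gradient; tanh's gradient is a reasonable approximation.
        Nd4j.getExecutioner().execAndReturn(new TanhDerivative(in, epsilon, in));
        return new Pair<>(in, null);
    }

    @Override
    public int hashCode() {
        return 1;
    }
    @Override
    public boolean equals(Object obj) {
        return obj instanceof ActivationBinary;
    }
    @Override
    public String toString() {
        return "Binary";
    }
}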
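To make the arithmetic behind this activation easier to follow, here is a minimal plain-Java sketch that does not depend on DL4J. It mirrors the forward pass (sign function) and the tanh-based surrogate gradient on ordinary arrays. The class name BinaryActivationSketch and the helpers toHashBits and hammingDistance are hypothetical, added only to illustrate how the resulting -1/+1 codes might be used as hash bits; they are not part of DL4J or of the answer above.

```java
import java.util.Arrays;

// Illustrative sketch only: plain arrays stand in for INDArray.
final class BinaryActivationSketch {

    /** Forward pass: negative pre-activations map to -1, all others to +1. */
    static double[] forward(double[] preActivations) {
        double[] out = new double[preActivations.length];
        for (int i = 0; i < out.length; i++) {
            out[i] = preActivations[i] < 0 ? -1.0 : 1.0;
        }
        return out;
    }

    /**
     * Backward pass: the step function has zero gradient almost everywhere,
     * so tanh'(x) = 1 - tanh(x)^2 is used as a stand-in and scaled by epsilon.
     */
    static double[] backward(double[] preActivations, double[] epsilon) {
        if (preActivations.length != epsilon.length) {
            throw new IllegalArgumentException("Shapes of input and epsilon must match");
        }
        double[] grad = new double[preActivations.length];
        for (int i = 0; i < grad.length; i++) {
            double t = Math.tanh(preActivations[i]);
            grad[i] = epsilon[i] * (1.0 - t * t);
        }
        return grad;
    }

    /** Hypothetical helper: pack a -1/+1 code into an int, bit i set when unit i is +1. */
    static int toHashBits(double[] binaryActivations) {
        if (binaryActivations.length > 32) {
            throw new IllegalArgumentException("At most 32 units fit into an int");
        }
        int bits = 0;
        for (int i = 0; i < binaryActivations.length; i++) {
            if (binaryActivations[i] > 0) {
                bits |= 1 << i;
            }
        }
        return bits;
    }

    /** Hypothetical helper: number of differing bits between two packed codes. */
    static int hammingDistance(int a, int b) {
        return Integer.bitCount(a ^ b);
    }

    public static void main(String[] args) {
        double[] first = forward(new double[] {-2.0, 0.0, 3.5});
        double[] second = forward(new double[] {1.0, -1.0, 2.0});
        System.out.println(Arrays.toString(first));
        System.out.println(Integer.toBinaryString(toHashBits(first)));
        System.out.println(hammingDistance(toHashBits(first), toHashBits(second)));
    }
}
```

In the configuration from the question, the custom activation would presumably be attached to the 10-unit layer only, by adding .activation(new ActivationBinary()) to that DenseLayer.Builder, so the other layers keep the ReLU default.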