1,高斯分布
高斯分布(Gaussian distribution)又称正态分布(normal distribution):若随机变量 服从一个数学期望为 ,方差为 的正态分布,则记为 。高斯分布的概率密度函数为正态分布,期望值 决定了其位置,其标准差 决定了分布的幅度。当 ,标准差 时的正态分布是标准正态分布。
2,KL散度
假设给定离散事件 , 则我们有以下定义:
- 概率:
- 信息:对 取对数,加符号得正值:,概率越高,包含的信息小,因为事件越来越确定。相反,概率越低,包含的信息越多,因为事件具有很大的不确定性。
- 香农熵: 对 平均:,熵是信息的平均,直观上,香农熵是信息在同一分布 下的平均。
- 交叉熵: 对 平均:,熵是信息的平均,直观上,交叉熵是信息在不同分布下的平均。
- KL散度(相对熵):相对熵 = 交叉熵 - 香农熵,非对称 ,亦不满足三角不等式,故不是距离。
若为连续事件:
- 香农熵:
- 交叉熵:
- 相对熵:
3,KL散度衡量两个高斯分布相似性
高斯分布为连续型分布,故
设 ,
故:
根据交叉熵公式反推:
由于
由于
由于标准差求和为0,故:
其中:p,q为torch.distributions.normal 表示正态分布。
def _kl_normal_normal(p, q): var_ratio = (p.scale / q.scale).pow(2) t1 = ((p.loc - q.loc) / q.scale).pow(2) return 0.5 * (var_ratio + t1 - 1 - var_ratio.log())
4,Wasserstein距离
两个多元高斯分布之间的2阶Wasserstein距离:
如果采用距离函数是欧几里得距离的话,那么两个分布之间的2阶Wasserstein距离是:
当协方差矩阵可以互换
当 与 都是对称矩阵时,有 :
此时:
其中:p,q为torch.distributions.normal 表示正态分布。
def ws_normal_normal(p, q): u = p.loc - q.loc p1 = torch.sum(torch.pow(u, 2), 1) p2 = torch.sum(torch.pow(torch.pow(p.scale, 1 / 2) - torch.pow(q.scale, 1 / 2), 2), 1) result = (p1 + p2).mean() return result