Python的对数似然和梯度函数实现

问题描述

关于科学论文https://arxiv.org/abs/1704.04289,我正在尝试实现第7.3节“优化超参数”。具体来说,论文第25页上的等式35。

对数似然函数负值似乎比通常的对数回归复杂。我尝试按照下面的代码为log reg实现负对数似然性和梯度下降。

import numpy as np
import pandas as pd
import sklearn
import matplotlib.pyplot as plt
%matplotlib inline

#simulating data to fit logistic regression model
np.random.seed(12)
num_observations = 5000

x1 = np.random.multivariate_normal([0,0],[[1,.75],[.75,1]],num_observations)
x2 = np.random.multivariate_normal([1,4],num_observations)

simulated_features = np.vstack((x1,x2)).astype(np.float32)
simulated_labels = np.hstack((np.zeros(num_observations),np.ones(num_observations)))

plt.figure(figsize=(12,8))

plt.scatter(simulated_features[:,simulated_features[:,1],c = simulated_labels,alpha = .4)

#add a column of ones to deal with bias term
ones = np.ones((num_observations*2,1))

Xb = np.hstack((ones,simulated_features))

#Activation Function

def sigmoid(scores):
    return 1 / (1 + np.exp(-scores))


#log-likelihood function

def log_likelihood(features,target,weights):
    #model output
    scores = np.dot(features,weights)
    nll = np.sum(target*scores - np.log(1 + np.exp(scores)))
    return nll

def log_reg(features,num_steps,learning_rate):
    weights = np.zeros(features.shape[1])
    
    for step in range(num_steps):
        score = np.dot(features,weights)
        predictions = sigmoid(scores)
        
        #update weights with gradient
        error = target - predictions
        gradient = np.dot(features.T,error)
        weights += learning_rate * gradient
        
        if step % 10000 == 0:
            print(log_likelihood(features,weights))
        
    return weights

我在这里面临的最大挑战是从本文的方程式中实现术语lambda,DK,theta(dk)和theta(dyn)。从理论上讲我理解了实现,并且能够在纸上手工解决它,但是我发现在使用一些模拟数据时很难在python上实现(如我的代码所示)。谁能指导我如何实现这一目标?

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...