问题描述
我不明白为什么我的感知器不起作用,它没有从训练数据中学习,我不确定为什么。每个时代都会出现错误并收敛,我在这里遗漏了什么吗? 这是源代码。
import numpy as np
import matplotlib.pyplot as plt
trainData = np.genfromtxt('train.data',delimiter=",")
testData = np.genfromtxt('test.data',")
class Perceptron():
def __init__(self,learnRate,maxEpoch):
self.learnRate = learnRate
self.maxEpoch = maxEpoch
def train(self,data):
np.random.seed(2)
x0 = np.ones((len(data),1))
w0 = np.zeros((1))
weights = np.random.rand((len(data[0]))-1)
weights = np.hstack([w0,weights])
data = np.hstack([x0,data])
np.random.shuffle(data)
label = data[:,5:]
data = data[:,:5]
iter = 0
for epoch in range(self.maxEpoch): #trainign
errors = 0
rowCount = 0
for row in data:
activation = np.inner(row,weights)
y = label[rowCount]# Class label
rowCount += 1
if y * activation <= 0:
for wx in weights:
wx = wx + y * row[int(wx)]
weights[0] = weights[0] + y
errors += 1
iter += 1
print(f'Errors total: {errors}')
return weights
def classPicker(classX,classY):
set1 = trainData[:40,:4]
set2 = trainData[40:80,:4]
set3 = trainData[80:120,:4]
class0 = np.full((40,1),-1)
class1 = np.ones((40,1))
if classX == 1 and classY == 2:
return np.vstack([np.hstack([set1,class0]),np.hstack([set2,class1])])
if classX == 2 and classY == 3:
return np.vstack([np.hstack([set2,np.hstack([set3,class1])])
if classX == 1 and classY == 3:
return np.vstack([np.hstack([set1,class1])])
#Class 1 & 2
pececptron1 = Perceptron(1,20)
trainDataset = (classPicker(1,2))
pececptron1.train(trainDataset)
#Class 2 & 3
#pececptron2 = Perceptron(1,1)
#trainDataset = classPicker(2,3)
#pececptron2.train(trainDataset)
#Class 1 & 3
#pececptron3 = Perceptron(1,1)
#trainDataset = classPicker(1,3)
#pececptron3.train(trainDataset)
数据集是鸢尾花数据集的修改版本。 数据数组在第 25 行看起来像这样,就在训练循环之前。
label = 40 x 1 array of 1 or -1 depending on class
weights = [0,0.4359949,0.02592623,0.54966248,0.43532239] first entry in all weights = w0 aka bias
data = [1,5,3.4,1.5,0.2] first entry in all = x0 (1) for the bias
解决方法
暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!
如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。
小编邮箱:dio#foxmail.com (将#修改为@)