此代码可以识别MNIST集吗? K-NN方法

问题描述

我不确定下面的代码是否会执行,因为它长期停留在“计算预测”上。如果无法正常工作,我该怎么做?

import struct
import matplotlib.pyplot as plt
import numpy as np
import os
from scipy.special import expit
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score



clf = KNeighborsClassifier()

def load_data():
    with open('train-labels-idx1-ubyte','rb') as labels:
        magic,n = struct.unpack('>II',labels.read(8))
        train_labels = np.fromfile(labels,dtype=np.uint8)
    with open('train-images-idx3-ubyte','rb') as imgs:
        magic,num,nrows,ncols = struct.unpack('>IIII',imgs.read(16))
        train_images = np.fromfile(imgs,dtype=np.uint8).reshape(num,784)
    with open('t10k-labels-idx1-ubyte',labels.read(8))
        test_labels = np.fromfile(labels,dtype=np.uint8)
    with open('t10k-images-idx3-ubyte',imgs.read(16))
        test_images = np.fromfile(imgs,784)
    return train_images,train_labels,test_images,test_labels


def knn(train_x,train_y,test_x,test_y):
    clf.fit(train_x,train_y)
    print("Compute predictions")
    predicted = clf.predict(test_x)
    print("Accuracy: ",accuracy_score(test_y,predicted))

train_x,test_y = load_data()
knn(train_x,test_y)

解决方法

它长期停留在“计算预测”上

在对整个数据集运行之前,我建议您使用一组非常有限的数据来测试一切是否正常。这样,您可以确保代码有意义。

一旦测试了代码,就可以安全地进行整个数据集的训练。

这样,您可以轻松地辨别代码是由于某些代码问题还是由于数据量而花费了很长时间(也许代码还可以,但是您可能会意识到,例如10个样本,比您愿意/可以等待的时间更长,因此您可以进行相应的调整-否则您正在处理的黑匣子就太多了。

话虽如此,如果代码还可以,但是花费的时间太长,我也建议像Soumya一样尝试在Colab上运行。您在那里有一些好的硬件,可以进行长达12个小时的会话,并且可以让您的PC同时免费测试其他代码!