问题描述
我正在尝试创建一个卷积神经网络,以在 Julia 中使用 Flux 对 MNIST 数据进行分类。我从以下链接下载了 csv 格式的数据:https://www.kaggle.com/oddrationale/mnist-in-csv。 我的代码如下:
using CSV
using DataFrames
using Images
using Base.Iterators: partition
using Flux
sqrt(x) = convert(Int64,floor(x^0.5))
matrixize(x) = cat([x[i:i+sqrt(length(x))-1] for i in 1:sqrt(length(x)):length(x)]...,dims=2)'
img(x) = Gray.(x)
process(row) = img(matrixize(row[2:length(row)])),convert(Int,255*row[1]) |> gpu
train_data = DataFrame(CSV.File("MNIST_data/mnist_train.csv"))
train_X = []
train_y = Int64[]
for row in eachrow(train_data)
row = convert(Array,row)
row = [i/255 for i in row]
X,y = process(row)
push!(train_X,X)
push!(train_y,y)
end
train_y = Flux.onehotbatch(train_y,0:9)
train = [(cat(float.(train_X[i])...,dims=4),train_y[:,i]) for i in partition(1:size(train_data)[1],1000)] |> gpu
test_data = DataFrame(CSV.File("MNIST_data/mnist_train.csv"))
test_X = []
test_y = Int64[]
for row in eachrow(test_data)
row = convert(Array,y = process(row)
push!(test_X,X)
push!(test_y,y)
end
test_y = Flux.onehotbatch(test_y,0:9)
println("Pre-processing Complete")
m = Chain(
Conv((5,5),1=>16,relu),MaxPool((2,2)),Conv((5,16=>8,Flux.flatten,Dense(200,100),Dense(100,10),Flux.softmax
) |> gpu
loss(x,y) = Flux.Losses.crossentropy(m(x),y) |> gpu
opt = Momentum(0.01) |> gpu
println("Model Creation Complete")
println()
epochs = 10
for i in 1:epochs
for j in train
gs = gradient(params(m)) do
l = loss(j...)
end
update!(opt,params(m),gs)
end
@show accuracy(test_X,test_y)
end
println()
@show accuracy(test_X,test_y)
当我检查 test_X、test_y、train_X 和 train_y 的值时,它们都采用适当的格式,但是当我尝试运行代码时出现此错误:
┌ Warning: Slow fallback implementation invoked for conv! You probably don't want this; check your datatypes.
│ yT = Float64
│ T1 = Gray{Float64}
│ T2 = Float32
└ @ NNlib /Users/satvikd/.julia/packages/NNlib/PI8Xh/src/conv.jl:206
┌ Warning: Slow fallback implementation invoked for conv! You probably don't want this; check your datatypes.
│ yT = Float64
│ T1 = Float64
│ T2 = Float32
└ @ NNlib /Users/satvikd/.julia/packages/NNlib/PI8Xh/src/conv.jl:206
DimensionMismatch("A has dimensions (100,200) but B has dimensions (128,1000)")
堆栈跟踪指的是第 55 行,即带有渐变的那一行。 任何帮助将不胜感激。
解决方法
看来您需要检查您的类型,因为其中一个是 float32 而另一个是 float64。通常,flux 默认使用 Float32,因为这对于深度学习任务来说已经足够精确了。您也可以使用 Flux.f64/f32。