问题描述
下面提到的代码取自model-zoo。我正在尝试使用通量库在 julia 中运行 vgg19 tutorial。
代码:
#model
using Flux
vgg19() = Chain(
Conv((3,3),3 => 64,relu,pad=(1,1),stride=(1,1)),Conv((3,64 => 64,MaxPool((2,2)),64 => 128,128 => 128,128 => 256,256 => 256,256 => 512,512 => 512,Batchnorm(512),flatten,Dense(512,4096,relu),Dropout(0.5),Dense(4096,10),softmax
)
#data
using MLDatasets: CIFAR10
using Flux: onehotbatch
# Data comes pre-normalized in Julia
trainX,trainY = CIFAR10.traindata(Float64)
testX,testY = CIFAR10.testdata(Float64)
# One hot encode labels
trainY = onehotbatch(trainY,0:9)
testY = onehotbatch(testY,0:9)
#training
using Flux: crossentropy,@epochs
using Flux.Data: DataLoader
model = vgg19()
opt = Momentum(.001,.9)
loss(x,y) = crossentropy(model(x),y)
data = DataLoader(trainX,trainY,batchsize=64)
@epochs 100 Flux.train!(loss,params(model),data,opt)
MethodError: no method matching ∇maxpool(::Array{Float32,4},::Array{Float64,::PoolDims{2,(2,2),(0,0),(1,1)})
Closest candidates are:
∇maxpool(::AbstractArray{T,N},!Matched::AbstractArray{T,::PoolDims; kwargs...) where {T,N}
请针对此错误提出一些解决方案,如果可能,请提供简要说明或参考。 提前致谢!
解决方法
正如@mcabbott 所提到的,问题与数据的输入类型有关。这可以通过将 type
从 Float64
更改为 Float32
来解决,#data
部分下面提到的参数。
trainX,trainY = CIFAR10.traindata(Float32)
testX,testY = CIFAR10.testdata(Float32)