在 Flux (Julia) 中运行 VGG19 模型时出现 MethodError

问题描述

下面提到的代码取自model-zoo。我正在尝试使用通量库在 julia 中运行 vgg19 tutorial

代码

#model
using Flux
vgg19() = Chain(            
    Conv((3,3),3 => 64,relu,pad=(1,1),stride=(1,1)),Conv((3,64 => 64,MaxPool((2,2)),64 => 128,128 => 128,128 => 256,256 => 256,256 => 512,512 => 512,Batchnorm(512),flatten,Dense(512,4096,relu),Dropout(0.5),Dense(4096,10),softmax
)

#data

using MLDatasets: CIFAR10
using Flux: onehotbatch
# Data comes pre-normalized in Julia
trainX,trainY = CIFAR10.traindata(Float64)
testX,testY = CIFAR10.testdata(Float64)
# One hot encode labels
trainY = onehotbatch(trainY,0:9)
testY = onehotbatch(testY,0:9)

#training

using Flux: crossentropy,@epochs
using Flux.Data: DataLoader
model = vgg19()
opt = Momentum(.001,.9)
loss(x,y) = crossentropy(model(x),y)
data = DataLoader(trainX,trainY,batchsize=64)
@epochs 100 Flux.train!(loss,params(model),data,opt)

当我在 IJulia 上执行这个文件时,抛出以下错误

MethodError: no method matching ∇maxpool(::Array{Float32,4},::Array{Float64,::PoolDims{2,(2,2),(0,0),(1,1)})
Closest candidates are:
  ∇maxpool(::AbstractArray{T,N},!Matched::AbstractArray{T,::PoolDims; kwargs...) where {T,N}

请针对此错误提出一些解决方案,如果可能,请提供简要说明或参考。 提前致谢!

解决方法

正如@mcabbott 所提到的,问题与数据的输入类型有关。这可以通过将 typeFloat64 更改为 Float32 来解决,#data 部分下面提到的参数。

trainX,trainY = CIFAR10.traindata(Float32)
testX,testY = CIFAR10.testdata(Float32)