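Problem description
Consider the following accumulator type. It works like an array that you can push onto, but it only keeps track of the mean of what was pushed: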
mutable struct Accumulator{T}
    data::T
    count::Int64
end
function Base.push!(acc::Accumulator, term)
    acc.data += term  # <-- in-place addition
    acc.count += 1
    acc
end
mean(acc::Accumulator) = acc.data ./ acc.count
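I would like this to work when T is either a scalar or an array type. However, it turns out that when T is an array type, the addition in push! creates a temporary. This is because in Julia, x += a is equivalent to x = x + a, and I suspect Julia cannot guarantee that acc.data and term do not alias.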
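To make the rebinding behaviour concrete, here is a minimal sketch (the variable names x and alias are only for illustration). It shows that += on an array produces a new array and rebinds the name, while the broadcast form .+= writes into the existing array:

```julia
x = [1.0, 2.0]
alias = x                 # a second name bound to the same array

x += [10.0, 20.0]         # same as x = x + [10.0, 20.0]: allocates a new array
@assert x !== alias       # x now refers to a different array
@assert alias == [1.0, 2.0]   # the original array was not modified

alias .+= [10.0, 20.0]    # broadcast assignment updates the existing array
@assert alias == [11.0, 22.0]
```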
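A simple fix is to replace += with the element-wise addition .+=. However, that breaks scalar types, which do not allow it. So the only way I have found to solve this is to add a specialization of the following form: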
function Base.push!(acc::Accumulator, term::AbstractArray)
    acc.data .+= term  # <-- element-wise addition
    acc.count += 1
    acc
end
This is a bit ugly, and fragile... Does anyone know a better way, ideally a generic one, that avoids creating temporaries?
Solution
Oddly enough, Numbers are iterable in Julia, but that does not seem to help here, because setindex! has no method for Number.
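A small sketch of what that means in practice. The helper try_inplace_add is hypothetical and exists only for this illustration; it reports whether an in-place broadcast addition succeeded:

```julia
n = 5

# For reading, a Number behaves like a single-element collection.
@assert length(n) == 1
@assert first(n) == 5
@assert n[1] == 5                  # getindex is defined for Number

# For writing, there is no setindex! method taking an Int as the container.
@assert !hasmethod(setindex!, Tuple{Int, Int, Int})

# Hypothetical helper: returns true if x .+= y worked, false if it threw.
function try_inplace_add(x, y)
    try
        x .+= y
        return true
    catch
        return false
    end
end

@assert try_inplace_add([1, 2], [3, 4])   # an array can be updated in place
@assert !try_inplace_add(5, 1)            # a scalar destination cannot
```

This is why a single .+= cannot serve both the scalar and the array case.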
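Here are two different approaches. The first uses iterator traits, and the second uses a few method signatures to handle the edge cases.
Iterator traits
We can use the IteratorSize trait to distinguish scalars from arrays. For a scalar, Base.IteratorSize(x) returns Base.HasShape{0}(). For an array, Base.IteratorSize(x) returns Base.HasShape{N}(), where N is the number of dimensions of the array.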
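As a quick check of the trait values described above (a sketch, not part of the original answer), IteratorSize can be queried on either a value or a type:

```julia
# Scalars report a zero-dimensional shape.
@assert Base.IteratorSize(1.5) == Base.HasShape{0}()
@assert Base.IteratorSize(Float64) == Base.HasShape{0}()

# Arrays report their number of dimensions.
@assert Base.IteratorSize([1, 2, 3]) == Base.HasShape{1}()
@assert Base.IteratorSize(zeros(2, 3)) == Base.HasShape{2}()
@assert Base.IteratorSize(Matrix{Float64}) == Base.HasShape{2}()
```

Because the trait works on types, dispatch can be decided from Accumulator's type parameter T alone, without inspecting acc.data.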
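mutable struct Accumulator{T}
    data::T
    count::Int64
end
# Dispatch on the shape traits of the accumulator's data type and the term's type.
function Base.push!(acc::Accumulator{T}, term::S) where {T,S}
    _push_acc!(Base.IteratorSize(T), Base.IteratorSize(S), acc, term)
end
# Scalar onto scalar: plain addition.
function _push_acc!(::Base.HasShape{0}, ::Base.HasShape{0}, acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end
# Array onto array of the same dimensionality: in-place element-wise addition.
function _push_acc!(::Base.HasShape{N}, ::Base.HasShape{N}, acc::Accumulator, term) where {N}
    acc.data .+= term
    acc.count += 1
    acc
end
# Any other combination of shapes is rejected.
function _push_acc!(::Base.HasShape{M}, ::Base.HasShape{N}, ::Accumulator, ::Any) where {M,N}
    throw(ArgumentError("Accumulator and term have inconsistent shapes"))
end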
In action at the REPL:
julia> a = Accumulator(1, 0)
Accumulator{Int64}(1, 0)
julia> b = Accumulator([1, 2], 0)
Accumulator{Array{Int64,1}}([1, 2], 0)
julia> push!(a, 42)
Accumulator{Int64}(43, 1)
julia> push!(b, [3, 4])
Accumulator{Array{Int64,1}}([4, 6], 1)
julia> push!(a, [5, 6])
ERROR: ArgumentError: Accumulator and term have inconsistent shapes
Stacktrace:
[1] _push_acc!(::Base.HasShape{0}, ::Base.HasShape{1}, ::Accumulator{Int64}, ::Array{Int64,1}) at ...
[2] push!(::Accumulator{Int64}, ::Array{Int64,1}) at ...
[3] top-level scope at REPL[6]:1
julia> push!(b, 10)
ERROR: ArgumentError: Accumulator and term have inconsistent shapes
Stacktrace:
[1] _push_acc!(::Base.HasShape{1}, ::Base.HasShape{0}, ::Accumulator{Array{Int64,1}}, ::Int64) at ...
[2] push!(::Accumulator{Array{Int64,1}}, ::Int64) at ...
[3] top-level scope at REPL[7]:1
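Tweaking the method signatures
Instead of using iterator traits, we can make a few small adjustments to your push! method signatures to prevent pushing an array onto a scalar.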
mutable struct Accumulator{T}
    data::T
    count::Int64
end
function Base.push!(acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end
function Base.push!(acc::Accumulator{T}, term::AbstractArray) where {T <: AbstractArray}
    acc.data .+= term
    acc.count += 1
    acc
end
function Base.push!(::Accumulator, ::AbstractArray)
    throw(ArgumentError("Can't push an array onto a scalar"))
end
Now, if we try to push an array onto a scalar, we get a sensible error message:
julia> a = Accumulator(42, 0)
Accumulator{Int64}(42, 0)
julia> push!(a, [1, 2])
ERROR: ArgumentError: Can't push an array onto a scalar
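As a rough sanity check of the signature-based version, the sketch below repeats the definitions (leaving out the error-throwing method for brevity) and confirms that the array method updates the stored array in place, while the scalar path still works. The names buffer, acc and s are illustrative, and mean is the one-line definition from the question:

```julia
mutable struct Accumulator{T}
    data::T
    count::Int64
end

# Generic fallback, used for scalars.
function Base.push!(acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end

# Array accumulator with array term: element-wise, in place.
function Base.push!(acc::Accumulator{T}, term::AbstractArray) where {T <: AbstractArray}
    acc.data .+= term
    acc.count += 1
    acc
end

mean(acc::Accumulator) = acc.data ./ acc.count

buffer = [0.0, 0.0]
acc = Accumulator(buffer, 0)
push!(acc, [1.0, 2.0])
push!(acc, [3.0, 4.0])

@assert acc.data === buffer       # still the same array object, so no rebinding happened
@assert mean(acc) == [2.0, 3.0]

s = Accumulator(0.0, 0)
push!(s, 1.0)
push!(s, 3.0)
@assert mean(s) == 2.0
```

The identity check with === only shows that the field was not rebound to a new array; measuring allocations directly would need a benchmarking tool.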