Is there a unified syntax for element-wise in-place operations on scalars and arrays in Julia?

Question

Consider the following accumulator type. It works a bit like an array in that you can push! onto it, but it only keeps track of the mean:

mutable struct Accumulator{T}
    data::T
    count::Int64
end

function Base.push!(acc::Accumulator,term)
    acc.data += term       # <-- in-place addition
    acc.count += 1
    acc
end

mean(acc::Accumulator) = acc.data ./ acc.count

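I would like this to work whether T is a scalar or an array type. However, it turns out that when T is an array type, the addition in push! creates a temporary. This is because in Julia x += a is equivalent to x = x + a, and I suspect Julia cannot guarantee that acc.data and term do not alias.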
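A minimal sketch of the rebinding behaviour described above, using plain arrays rather than the Accumulator type (the variable names are only illustrative):

```julia
# `x += a` is shorthand for `x = x + a`: the sum is computed into a
# freshly allocated array, and the name `x` is then rebound to it.
x = [1.0, 2.0]
alias = x                 # a second name for the same array
x += [10.0, 20.0]

println(x === alias)      # false: x now refers to a new array
println(alias)            # [1.0, 2.0]: the original array is untouched
println(x)                # [11.0, 22.0]
```

The original array is left unchanged, which is why each push! on an array-valued accumulator allocates.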

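A simple fix is to replace += with the element-wise form .+=. However, that breaks scalar types, which do not support it. So the only way I have found to solve this is to add a specialization of the following form: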

function Base.push!(acc::Accumulator,term::AbstractArray)
    acc.data .+= term       # <-- element-wise addition
    acc.count += 1
    acc
end
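To illustrate why a single .+= method does not cover both cases, here is a small sketch. Box is a hypothetical stand-in for Accumulator with only the data field; the exact exception type raised for the scalar case is not relied on here.

```julia
# `Box` is a hypothetical, stripped-down stand-in for `Accumulator`.
mutable struct Box{T}
    data::T
end

# Array field: `.+=` writes into the existing array, no new array is bound.
arr = Box([1.0, 2.0])
before = arr.data
arr.data .+= [10.0, 20.0]
println(arr.data === before)   # true: same array object, updated in place

# Scalar field: an Int cannot be the destination of an in-place broadcast.
num = Box(1)
scalar_failed = try
    num.data .+= 41
    false
catch err
    true
end
println(scalar_failed)         # true
```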

This is a bit ugly and fragile... Does anyone know a better way, ideally a generic one that does not create temporaries?

Answer

Curiously, numbers are iterable in Julia, but that does not seem to help here, because setindex! has no method for Number.
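A short sketch of what this means in practice: a number can be read like a one-element collection, but there is nothing to write into.

```julia
n = 5

# Reading works: a number behaves like a one-element collection.
println(first(n))     # 5
println(length(n))    # 1
println(n[1])         # 5

# Writing does not: no setindex! method accepts an Int as the container.
println(hasmethod(setindex!, Tuple{Int, Int, Int}))   # false
```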

Here are two different approaches. The first uses iterator traits; the second adds a few method signatures to handle the edge cases.

Iterator traits

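We can use the IteratorSize trait to distinguish scalars from arrays. For a scalar, Base.IteratorSize(x) returns Base.HasShape{0}(). For an array, Base.IteratorSize(x) returns Base.HasShape{N}(), where N is the number of dimensions of the array. The trait can also be queried on the type itself, which is what the code below does.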
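As a quick check of the trait values mentioned above, the following sketch queries IteratorSize on a few values and types:

```julia
# Scalars report a zero-dimensional shape, whether queried by value or by type.
println(Base.IteratorSize(1.5) isa Base.HasShape{0})              # true
println(Base.IteratorSize(Int) isa Base.HasShape{0})              # true

# Arrays report their dimensionality.
println(Base.IteratorSize([1, 2, 3]) isa Base.HasShape{1})        # true
println(Base.IteratorSize(Matrix{Float64}) isa Base.HasShape{2})  # true
```

Because the trait depends only on the type, dispatching on it can be resolved at compile time.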

mutable struct Accumulator{T}
    data::T
    count::Int64
end

function Base.push!(acc::Accumulator{T},term::S) where {T,S}
    _push_acc!(Base.IteratorSize(T),Base.IteratorSize(S),acc,term)
end

function _push_acc!(::Base.HasShape{0},::Base.HasShape{0},acc::Accumulator,term)
    acc.data += term
    acc.count += 1
    acc
end

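# Array case: both shapes have the same dimensionality, so update in place.
function _push_acc!(::Base.HasShape{N},::Base.HasShape{N},acc::Accumulator,term) where {N}
    acc.data .+= term
    acc.count += 1
    acc
end

# Fallback: the accumulator and the term have different shapes.
function _push_acc!(::Base.HasShape{M},::Base.HasShape{N},::Accumulator,::Any) where {M,N}
    throw(ArgumentError("Accumulator and term have inconsistent shapes"))
end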

In action at the REPL:

julia> a = Accumulator(1,0)
Accumulator{Int64}(1,0)

julia> b = Accumulator([1,2],0)
Accumulator{Array{Int64,1}}([1,2],0)

julia> push!(a,42)
Accumulator{Int64}(43,1)

julia> push!(b,[3,4])
Accumulator{Array{Int64,1}}([4,6],1)

julia> push!(a,[5,6])
ERROR: ArgumentError: Accumulator and term have inconsistent shapes
Stacktrace:
 [1] _push_acc!(::Base.HasShape{0},::Base.HasShape{1},::Accumulator{Int64},::Array{Int64,1}) at ...
 [2] push!(::Accumulator{Int64},::Array{Int64,1}) at ...
 [3] top-level scope at REPL[6]:1

julia> push!(b,10)
ERROR: ArgumentError: Accumulator and term have inconsistent shapes
Stacktrace:
 [1] _push_acc!(::Base.HasShape{1},::Base.HasShape{0},::Accumulator{Array{Int64,1}},::Int64) at ...
 [2] push!(::Accumulator{Array{Int64,1}},::Int64) at ...
 [3] top-level scope at REPL[7]:1

Tweaking the method signatures

Instead of using iterator traits, we can make a few small tweaks to your push! method signatures so that an array cannot be pushed onto a scalar.

mutable struct Accumulator{T}
    data::T
    count::Int64
end

function Base.push!(acc::Accumulator,term)
    acc.data += term
    acc.count += 1
    acc
end

function Base.push!(acc::Accumulator{T},term::AbstractArray) where {T <: AbstractArray}
    acc.data .+= term
    acc.count += 1
    acc
end

function Base.push!(::Accumulator,::AbstractArray)
    throw(ArgumentError("Can't push an array onto a scalar"))
end

Now, if we try to push an array onto a scalar, we get a sensible error message:

julia> a = Accumulator(42,0)
Accumulator{Int64}(42,0)

julia> push!(a,[1,2])
ERROR: ArgumentError: Can't push an array onto a scalar
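One way to check that the array method really updates in place is to keep a reference to the original array and compare identities after the push. The sketch below repeats the definitions from this approach so that it runs on its own; the variable names buffer, b and s are only illustrative.

```julia
mutable struct Accumulator{T}
    data::T
    count::Int64
end

# Generic (scalar) method: rebinding is harmless for immutable numbers.
function Base.push!(acc::Accumulator, term)
    acc.data += term
    acc.count += 1
    acc
end

# Array-on-array method: broadcast into the existing storage.
function Base.push!(acc::Accumulator{T}, term::AbstractArray) where {T<:AbstractArray}
    acc.data .+= term
    acc.count += 1
    acc
end

# Array-on-scalar: reject with a clear message.
function Base.push!(::Accumulator, ::AbstractArray)
    throw(ArgumentError("Can't push an array onto a scalar"))
end

buffer = [1, 2]
b = Accumulator(buffer, 0)
push!(b, [3, 4])
println(b.data === buffer)   # true: the accumulator still owns the same array
println(buffer)              # [4, 6]

s = Accumulator(1, 0)
push!(s, 42)
println(s.data)              # 43
```

Note that the accumulator mutates the array it was constructed with, so a caller who wants to keep the initial array should pass in a copy.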