Swift 中的 SIMD 操作对于跨步数组的混合操作很慢

问题描述

我想在我的 swift 代码中计算跨距数组的符号函数(例如 vDSP 函数)。为此,我使用了 simd 框架。但是,正如您在下面的代码中看到的那样,我的代码可能有点多余。 (实际上,尽管使用了 SIMD!,但这段代码比纯 swift 实现(0.03s)慢(0.06s))。具体来说,我认为 SIMD64<Float> 初始化(a)和复制目标指针(b)是瓶颈部分(见下面的代码)。

注意,我对 strided 数组的 10e6 个元素的比较。

所以,我的问题是;

  • 我应该如何以最快的方式从跨步数组(即 SIMD64<Float>UnsafePointer<Float>)初始化 vDSP_Stride
  • 如何使用 SIMD64<Float> 将计算出的 UnsafeMutablePointer<Float> 传递给 vDSP_Stride
for (n = 0; n < N; ++n)
   if (A[n*IA] > 0)
       *D[n*ID] = *B;
   else if (A[n*IA] < 0)
       *D[n*ID] = *C;
   else
       *D[n*ID] = 0;

我的 SIMD 代码

import Accelerate
import simd

/**
 Apply sign function for the given single-precision vector.
 
 - Parameters:
    - __A: Single-precision real input vector
    - __IA: Stride for A
    - __B: Pointer to single-precision real input scalar: upper destination
    - __C: Pointer to single-precision real input scalar: lower destination
    - __D: Single-precision real output vector
    -  __ID: Stride for D
    - __N: The number of elements to process
 
 */
func evDSP_sign(_ __A: UnsafePointer<Float>,_ __IA: vDSP_Stride,_ __B: UnsafePointer<Float>,_ __C: UnsafePointer<Float>,_ __D: UnsafeMutablePointer<Float>,_ __ID: vDSP_Stride,_ __N: vDSP_Length){
    var __A = UnsafeMutablePointer(mutating: __A)
    var __D = UnsafeMutablePointer(mutating: __D)
   let strideA = Int(__IA)
   let strideD = Int(__ID)
   
    let (iterations,remaining) = (Int(__N) / 64,Int(__N) % 64)
    let zeros = SIMD64<Float>(repeating: 0)
    for _ in 0..<iterations{
        ///////////////////// (a) ///////////////////////////////////////////
        var a = SIMD64<Float>(__A.pointee,(__A + __IA).pointee,(__A + 2*__IA).pointee,(__A + 3*__IA).pointee,(__A + 4*__IA).pointee,(__A + 5*__IA).pointee,(__A + 6*__IA).pointee,(__A + 7*__IA).pointee,(__A + 8*__IA).pointee,(__A + 9*__IA).pointee,(__A + 10*__IA).pointee,(__A + 11*__IA).pointee,(__A + 12*__IA).pointee,(__A + 13*__IA).pointee,(__A + 14*__IA).pointee,(__A + 15*__IA).pointee,(__A + 16*__IA).pointee,(__A + 17*__IA).pointee,(__A + 18*__IA).pointee,(__A + 19*__IA).pointee,(__A + 20*__IA).pointee,(__A + 21*__IA).pointee,(__A + 22*__IA).pointee,(__A + 23*__IA).pointee,(__A + 24*__IA).pointee,(__A + 25*__IA).pointee,(__A + 26*__IA).pointee,(__A + 27*__IA).pointee,(__A + 28*__IA).pointee,(__A + 29*__IA).pointee,(__A + 30*__IA).pointee,(__A + 31*__IA).pointee,(__A + 32*__IA).pointee,(__A + 33*__IA).pointee,(__A + 34*__IA).pointee,(__A + 35*__IA).pointee,(__A + 36*__IA).pointee,(__A + 37*__IA).pointee,(__A + 38*__IA).pointee,(__A + 39*__IA).pointee,(__A + 40*__IA).pointee,(__A + 41*__IA).pointee,(__A + 42*__IA).pointee,(__A + 43*__IA).pointee,(__A + 44*__IA).pointee,(__A + 45*__IA).pointee,(__A + 46*__IA).pointee,(__A + 47*__IA).pointee,(__A + 48*__IA).pointee,(__A + 49*__IA).pointee,(__A + 50*__IA).pointee,(__A + 51*__IA).pointee,(__A + 52*__IA).pointee,(__A + 53*__IA).pointee,(__A + 54*__IA).pointee,(__A + 55*__IA).pointee,(__A + 56*__IA).pointee,(__A + 57*__IA).pointee,(__A + 58*__IA).pointee,(__A + 59*__IA).pointee,(__A + 60*__IA).pointee,(__A + 61*__IA).pointee,(__A + 62*__IA).pointee,(__A + 63*__IA).pointee)
        
        a.replace(with: __B.pointee,where: a .> zeros)
        a.replace(with: __C.pointee,where: a .< zeros)
       
       ///////////////////// (b) ///////////////////////////////////////////
       withUnsafePointer(to: &a){
           ptr in
           ptr.withMemoryRebound(to: Float.self,capacity: 64){
               cblas_scopy(Int32(64),$0,Int32(1),__D,Int32(__ID))
           }
       }
       // proceeds offset
       __A += 64*strideA
       __D += 64*strideD

    }
    // remaining
   for _ in 0..<remaining{
       var tmp: Float
       if __A.pointee > 0{
           tmp = __B.pointee
       }
       else if __A.pointee < 0{
           tmp = __C.pointee
       }
       else{
           tmp = .zero
       }
       __D.assign(from: &tmp,count: 1)
       __A += strideA
       __D += strideD
   }
    
}

纯swift代码; 注意 withContiguousDataUnsafeMPtrT 是我获取跨步数组元素的方法

func _sign<T: MfStorable>(ret: MyStriderdArrayClass,low: T,high: T) -> MyStriderdArrayClass{
            ret.withContiguousDataUnsafeMPtrT(datatype: T.self){
                if $0.pointee > .zero{
                    $0.pointee = high
                }
                else if $0.pointee < .zero{
                    $0.pointee = low
                }
            }
            return ret
        }

解决方法

暂无找到可以解决该程序问题的有效方法,小编努力寻找整理中!

如果你已经找到好的解决方法,欢迎将解决方案带上本链接一起发送给小编。

小编邮箱:dio#foxmail.com (将#修改为@)