变换推力矢量的类型

问题描述

我对 CUDA 和推力有点陌生,目前我正在努力解决以下问题:

我有 2 个结构体携带数据。

struct S1 {
    unsigned long A,B,C;
}

struct S2 {
    double A,C;
}

在程序开始时,我有第一个结构体的向量,我想使用 GPU 以及一个特殊的函子将 S1 的向量转换为 S2 的向量。结果将与输入的大小相同,但只有不同类型的元素。

我目前正在使用此设置:

struct Functor {
    Functor() {}

    __host__ __ device__ S2 operator() (const S1& s1) {
        // perform some operation on S1 to convert it so S2
        // here I just copy the values
        return S2{s1.A,s1.B,s1.C};
    }
}

void main() {
    std::vector<S1> input;
    // fill the 'input' vector

    // move the input to a device vector
    thrust::device_vector<S1> input_d(input);
    // empty vector for struct S2
    thrust::device_vector<S2> output_d(input.size());

    thrust::transform(input_d.begin(),input_d.end(),output_d.begin(),output_d.end(),Functor());

    return 0;
}

函子 Functor 负责将 S1 转换为 S2(在本例中简化)。

此代码导致编译器错误,因为 operator() 需要 2 个参数,而我只想有一个输入。环顾网络,我没有找到适用于我的问题的解决方案。

解决方法

正如评论中已经指出的那样,我误读了文档(参见 here)。解决方案是使用一元函子和另一种调用 thrust::transform 的方式。

// this functor converts values of S1 to S1 by multiplying them with .5
// in the actual scenario this method does perform much more useful operations
struct Functor : public thrust::unary_function<S1,S2> {
    Functor() {}

    __host__ __device__ S2 operator() (const S1& s1) const {
        // very basic operations on the values of s1
        double a = s1.A * .5;
        double b = s1.B * .5;
        double c = s1.C * .5;
        return S2{a,b,c};
    }
}

然后将转换称为:

thrust::host_vector<S1> tmp;
// fill the tmp vector with S1 instances...

// move the host vector to the device
thrust::device_vector<S1> input_d;
input_d = tmp; 

// initializing the output with the same size as the input
thrust::device_vector<S2> output_d(tmp.size()); 

// calling the functor on all elements of the input and store them in the output
thrust::transform(input_d.begin(),input_d.end(),output_d.begin(),Functor()); 

编辑:添加了更多代码部分,因此它实际上是工作代码。

相关问答

错误1:Request method ‘DELETE‘ not supported 错误还原:...
错误1:启动docker镜像时报错:Error response from daemon:...
错误1:private field ‘xxx‘ is never assigned 按Alt...
报错如下,通过源不能下载,最后警告pip需升级版本 Requirem...