问题描述
我在 C++ 中有一个相当大的 Eigen 矩阵,现在希望每列中最多只保留 6 个非零值。 因此,我想把每列中除最大的 6 个值以外的所有值都设置为 0。 我不知道该怎么做,有人能帮我吗?
解决方法
我会这样做。
void trimMatrix( Eigen::MatrixXd& matrix )
{
constexpr size_t elementsToKeep = 6;
std::vector<uint32_t> vec;
vec.reserve( matrix.cols() );
for( ptrdiff_t col = 0; col < matrix.cols(); col++ )
{
// Collect non-zero elements from a column of the matrix
vec.clear();
// BTW,when the matrix is sparse and column major,Eigen has a faster way to iterate over non-zero elements.
for( ptrdiff_t r = 0; r < matrix.rows(); r++ )
{
double e = matrix( r,col );
if( e != 0.0 )
vec.push_back( (uint32_t)r );
}
if( vec.size() <= elementsToKeep )
continue; // Not enough non zero elements,nothing to do for the column.
// Partition the vector into 2 sorted pieces.
// Standard library has an algorithm for such partition,faster than sorting.
// BTW the code is only good for column major matrices.
// For row major ones RAM access pattern is bad,need another way.
std::nth_element( vec.begin(),vec.begin() + elementsToKeep,vec.end(),[&matrix,col]( uint32_t a,uint32_t b )
{
const double e1 = matrix( a,col );
const double e2 = matrix( b,col );
// Using `>` for order because we want top N elements before elementsToKeep
return std::abs( e1 ) > std::abs( e2 );
} );
// Zero out elements outside of the top N
for( auto it = vec.begin() + elementsToKeep; it != vec.end(); it++ )
matrix( *it,col ) = 0.0;
}
}
,
//std
#include <algorithm>
#include <iostream>
#include <numeric>
#include <vector>
//eigen
#include <Eigen/Dense>
using namespace std;
/// Demo: zero every entry of one column of a small matrix except its largest values.
int main()
{
    // How many of the largest entries to preserve (2 here; 6 in the original question).
    constexpr size_t keepCount = 2;

    Eigen::MatrixXf m( 4, 4 );
    m <<  1,  2,  3,  4,
          5,  6,  7,  8,
          9, 10, 11, 12,
         13, 14, 15, 16;

    // View onto the 2nd column; writes through it modify `m`.
    auto col1 = m.col( 1 );

    // One index per row of the column: 0, 1, 2, ...
    std::vector<size_t> idx( static_cast<size_t>( m.rows() ) );
    std::iota( idx.begin(), idx.end(), size_t{ 0 } );

    // Order the indices by the value they refer to, smallest first.
    std::sort( idx.begin(), idx.end(),
        [&col1]( size_t lhs, size_t rhs ) { return col1( lhs ) < col1( rhs ); } );

    // Drop the indices of the `keepCount` biggest values from the list.
    // Clamped so a column shorter than keepCount cannot underflow the size.
    idx.resize( idx.size() - std::min( keepCount, idx.size() ) );

    // Whatever is still listed is not among the biggest values: set it to 0.
    for( const size_t id : idx )
        col1( id ) = 0;

    std::cout << m << '\n';
    // Output :
    // 1 0 3 4
    // 5 0 7 8
    // 9 10 11 12
    // 13 14 15 16
    return 0;
}