问题描述
我正在编写一个函数,该函数从稀疏向量中提取前 x 值(如果小于 x,则值更少)。我想像许多函数一样包含一个“就地”选项,如果选项为 True,它会删除最高值,如果选项为 False,则保留它们。
我的问题是我当前的函数正在覆盖输入向量,而不是保持原样。我不确定为什么会发生这种情况。我希望解决我的设计问题的方法是包含一个 if 语句,该语句将使用 copy.copy() 复制输入,但这会引发值错误 (ValueError: row index exceeded matrix维数),这对我来说没有意义.
代码:
from scipy.sparse import csr_matrix
import copy
max_loc=20
data=[1,3,2,5]
rows=[0]*len(data)
indices=[4,8,12,7]
sparse_test=csr_matrix((data,(rows,indices)),shape=(1,max_loc))
print(sparse_test)
def top_x_in_sparse(in_vect,top_x,inplace=False):
if inplace==True:
rvect=in_vect
else:
rvect=copy.copy(in_vect)
newmax=top_x
count=0
out_list=[]
while newmax>0:
newmax=1
if count<top_x:
out_list+=[csr_matrix.max(rvect)]
remove=csr_matrix.argmax(rvect)
rvect[0,remove]=0
rvect.eliminate_zeros()
newmax=csr_matrix.max(rvect)
count+=1
else:
newmax=0
return out_list
a=top_x_in_sparse(sparse_test,3)
print(a)
print(sparse_test)
我的问题有两个部分:
解决方法
你真的只是不想循环周期。每次使用 .eliminate_zeros()
进行循环迭代时重新分配这些数组是最慢的事情,但不是不这样做的唯一原因。
import numpy as np
from scipy.sparse import csr_matrix
max_loc=20
data=[1,3,2,5]
rows=[0]*len(data)
indices=[4,8,12,7]
sparse_test=csr_matrix((data,(rows,indices)),shape=(1,max_loc))
这样的东西会更好:
def top_x_in_sparse(in_vect,top_x,inplace=False):
n = len(in_vect.data)
if top_x >= n:
if inplace:
out_data = in_vect.data.tolist()
in_vect.data = np.array([],dtype=in_vect.data.dtype)
in_vect.indices = np.array([],dtype=in_vect.indices.dtype)
in_vect.indptr = np.array([0,0],dtype=in_vect.indptr.dtype)
return out_data
else:
return in_vect.data.tolist()
else:
k = n - top_x
partition_idx = np.argpartition(in_vect.data,k)
if inplace:
out_data = in_vect.data[partition_idx[k:n]].tolist()
in_vect.data = in_vect.data[partition_idx[0:k]]
in_vect.indices = in_vect.indices[partition_idx[0:k]]
in_vect.indptr = np.array([0,len(in_vect.data)],dtype=in_vect.indptr.dtype)
return out_data
else:
return in_vect.data[partition_idx[k:n]].tolist()
如果您需要对返回的值进行排序,您当然也可以这样做。
a=top_x_in_sparse(sparse_test,inplace=False)
>>> print(a)
[3,5,3]
>>> print(sparse_test)
(0,2) 3
(0,4) 1
(0,7) 5
(0,8) 3
(0,12) 2
b=top_x_in_sparse(sparse_test,inplace=True)
>>> print(b)
[3,12) 2
同样根据您在评论中提出的问题:稀疏数组对象的浅拷贝不会复制保存数据的 numpy 数组。稀疏对象只有对这些对象的引用。深拷贝会得到它们,但使用内置的拷贝方法已经知道哪些被引用的东西需要被拷贝,哪些不需要。